pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119) hide show
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
  2. pytme-0.2.9.data/scripts/match_template.py +1135 -0
  3. pytme-0.2.9.data/scripts/postprocess.py +622 -0
  4. pytme-0.2.9.data/scripts/preprocess.py +209 -0
  5. pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
  6. pytme-0.2.9.dist-info/METADATA +95 -0
  7. pytme-0.2.9.dist-info/RECORD +119 -0
  8. pytme-0.2.9.dist-info/WHEEL +5 -0
  9. pytme-0.2.9.dist-info/entry_points.txt +6 -0
  10. pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
  11. pytme-0.2.9.dist-info/top_level.txt +3 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +97 -0
  14. scripts/match_template.py +1135 -0
  15. scripts/postprocess.py +622 -0
  16. scripts/preprocess.py +209 -0
  17. scripts/preprocessor_gui.py +1227 -0
  18. tests/__init__.py +0 -0
  19. tests/data/Blurring/blob_width18.npy +0 -0
  20. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  21. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  22. tests/data/Blurring/hamming_width6.npy +0 -0
  23. tests/data/Blurring/kaiserb_width18.npy +0 -0
  24. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  25. tests/data/Blurring/mean_size5.npy +0 -0
  26. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  27. tests/data/Blurring/rank_rank3.npy +0 -0
  28. tests/data/Maps/.DS_Store +0 -0
  29. tests/data/Maps/emd_8621.mrc.gz +0 -0
  30. tests/data/README.md +2 -0
  31. tests/data/Raw/em_map.map +0 -0
  32. tests/data/Structures/.DS_Store +0 -0
  33. tests/data/Structures/1pdj.cif +3339 -0
  34. tests/data/Structures/1pdj.pdb +1429 -0
  35. tests/data/Structures/5khe.cif +3685 -0
  36. tests/data/Structures/5khe.ent +2210 -0
  37. tests/data/Structures/5khe.pdb +2210 -0
  38. tests/data/Structures/5uz4.cif +70548 -0
  39. tests/preprocessing/__init__.py +0 -0
  40. tests/preprocessing/test_compose.py +76 -0
  41. tests/preprocessing/test_frequency_filters.py +178 -0
  42. tests/preprocessing/test_preprocessor.py +136 -0
  43. tests/preprocessing/test_utils.py +79 -0
  44. tests/test_analyzer.py +216 -0
  45. tests/test_backends.py +446 -0
  46. tests/test_density.py +503 -0
  47. tests/test_extensions.py +130 -0
  48. tests/test_matching_cli.py +283 -0
  49. tests/test_matching_data.py +162 -0
  50. tests/test_matching_exhaustive.py +124 -0
  51. tests/test_matching_memory.py +30 -0
  52. tests/test_matching_optimization.py +226 -0
  53. tests/test_matching_utils.py +189 -0
  54. tests/test_orientations.py +175 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_rotations.py +153 -0
  57. tests/test_structure.py +247 -0
  58. tme/__init__.py +6 -0
  59. tme/__version__.py +1 -0
  60. tme/analyzer/__init__.py +2 -0
  61. tme/analyzer/_utils.py +186 -0
  62. tme/analyzer/aggregation.py +577 -0
  63. tme/analyzer/peaks.py +953 -0
  64. tme/backends/__init__.py +171 -0
  65. tme/backends/_cupy_utils.py +734 -0
  66. tme/backends/_jax_utils.py +188 -0
  67. tme/backends/cupy_backend.py +294 -0
  68. tme/backends/jax_backend.py +314 -0
  69. tme/backends/matching_backend.py +1270 -0
  70. tme/backends/mlx_backend.py +241 -0
  71. tme/backends/npfftw_backend.py +583 -0
  72. tme/backends/pytorch_backend.py +430 -0
  73. tme/data/__init__.py +0 -0
  74. tme/data/c48n309.npy +0 -0
  75. tme/data/c48n527.npy +0 -0
  76. tme/data/c48n9.npy +0 -0
  77. tme/data/c48u1.npy +0 -0
  78. tme/data/c48u1153.npy +0 -0
  79. tme/data/c48u1201.npy +0 -0
  80. tme/data/c48u1641.npy +0 -0
  81. tme/data/c48u181.npy +0 -0
  82. tme/data/c48u2219.npy +0 -0
  83. tme/data/c48u27.npy +0 -0
  84. tme/data/c48u2947.npy +0 -0
  85. tme/data/c48u3733.npy +0 -0
  86. tme/data/c48u4749.npy +0 -0
  87. tme/data/c48u5879.npy +0 -0
  88. tme/data/c48u7111.npy +0 -0
  89. tme/data/c48u815.npy +0 -0
  90. tme/data/c48u83.npy +0 -0
  91. tme/data/c48u8649.npy +0 -0
  92. tme/data/c600v.npy +0 -0
  93. tme/data/c600vc.npy +0 -0
  94. tme/data/metadata.yaml +80 -0
  95. tme/data/quat_to_numpy.py +42 -0
  96. tme/data/scattering_factors.pickle +0 -0
  97. tme/density.py +2263 -0
  98. tme/extensions.cpython-311-darwin.so +0 -0
  99. tme/external/bindings.cpp +332 -0
  100. tme/filters/__init__.py +6 -0
  101. tme/filters/_utils.py +311 -0
  102. tme/filters/bandpass.py +230 -0
  103. tme/filters/compose.py +81 -0
  104. tme/filters/ctf.py +393 -0
  105. tme/filters/reconstruction.py +160 -0
  106. tme/filters/wedge.py +542 -0
  107. tme/filters/whitening.py +191 -0
  108. tme/matching_data.py +863 -0
  109. tme/matching_exhaustive.py +497 -0
  110. tme/matching_optimization.py +1311 -0
  111. tme/matching_scores.py +1183 -0
  112. tme/matching_utils.py +1188 -0
  113. tme/memory.py +337 -0
  114. tme/orientations.py +598 -0
  115. tme/parser.py +685 -0
  116. tme/preprocessor.py +1329 -0
  117. tme/rotations.py +350 -0
  118. tme/structure.py +1864 -0
  119. tme/types.py +13 -0
tme/structure.py ADDED
@@ -0,0 +1,1864 @@
1
+ """ Implements class Structure to represent atomic structures.
2
+
3
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ import warnings
9
+ from copy import deepcopy
10
+ from itertools import groupby
11
+ from dataclasses import dataclass
12
+ from collections import namedtuple
13
+ from typing import List, Dict, Tuple
14
+ from os.path import splitext, basename
15
+
16
+ import numpy as np
17
+
18
+ from .types import NDArray
19
+ from .rotations import align_to_axis
20
+ from .preprocessor import atom_profile, Preprocessor
21
+ from .parser import PDBParser, MMCIFParser, GROParser
22
+ from .matching_utils import rigid_transform, minimum_enclosing_box
23
+
24
+ __all__ = ["Structure"]
25
+
26
+
27
+ @dataclass(repr=False)
28
+ class Structure:
29
+ """
30
+ Represents atomic structures per the Protein Data Bank (PDB) specification.
31
+
32
+ Examples
33
+ --------
34
+ The following achieves the definition of a :py:class:`Structure` instance
35
+
36
+ >>> from tme import Structure
37
+ >>> structure = Structure(
38
+ >>> record_type=["ATOM", "ATOM", "ATOM"],
39
+ >>> atom_serial_number=[0, 1, 2] ,
40
+ >>> atom_name=["C", "N", "H"],
41
+ >>> atom_coordinate=[[30,15,10], [35, 20, 15], [35,25,20]],
42
+ >>> alternate_location_indicator=[".", ".", "."],
43
+ >>> residue_name=["GLY", "GLY", "HIS"],
44
+ >>> chain_identifier=["A", "A", "B"],
45
+ >>> residue_sequence_number=[0, 0, 1],
46
+ >>> code_for_residue_insertion=["?", "?", "?"],
47
+ >>> occupancy=[0, 0, 0],
48
+ >>> temperature_factor=[0, 0, 0],
49
+ >>> segment_identifier=["1", "1", "1"],
50
+ >>> element_symbol=["C", "N", "C"],
51
+ >>> charge=["?", "?", "?"],
52
+ >>> metadata={},
53
+ >>> )
54
+ >>> structure
55
+ Unique Chains: A-B, Atom Range: 0-2 [N = 3], Residue Range: 0-1 [N = 3]
56
+
57
+ :py:class:`Structure` instances support a range of subsetting operations based on
58
+ atom indices
59
+
60
+ >>> structure[1]
61
+ Unique Chains: A, Atom Range: 1-1 [N = 1], Residue Range: 0-0 [N = 1]
62
+ >>> structure[(False, False, True)]
63
+ Unique Chains: B, Atom Range: 2-2 [N = 1], Residue Range: 1-1 [N = 1]
64
+ >>> structure[(1,2)]
65
+ Unique Chains: A-B, Atom Range: 1-2 [N = 2], Residue Range: 0-1 [N = 2]
66
+
67
+ They can be written to disk in a range of formats using :py:meth:`Structure.to_file`
68
+
69
+ >>> structure.to_file("test.pdb") # Writes a PDB file to disk
70
+ >>> structure.to_file("test.cif") # Writes a mmCIF file to disk
71
+
72
+ New instances can be created from a range of formats using
73
+ :py:meth:`Structure.from_file`
74
+
75
+ >>> Structure.from_file("test.pdb") # Reads PDB file from disk
76
+ Unique Chains: A-B, Atom Range: 0-2 [N = 3], Residue Range: 0-1 [N = 3]
77
+ >>> Structure.from_file("test.cif") # Reads mmCIF file from disk
78
+ Unique Chains: A-B, Atom Range: 0-2 [N = 3], Residue Range: 0-1 [N = 3]
79
+
80
+ Class instances can be discretized on grids and converted to
81
+ :py:class:`tme.density.Density` instances using :py:meth:`Structure.to_volume`
82
+ or :py:meth:`tme.density.Density.from_structure`.
83
+
84
+ >>> volume, origin, sampling_rate = structure.to_volume(shape=(50,40,30))
85
+
86
+ References
87
+ ----------
88
+ .. [1] https://www.cgl.ucsf.edu/chimera/docs/UsersGuide/tutorials/pdbintro.html
89
+ .. [2] https://www.ccp4.ac.uk/html/mmcifformat.html
90
+
91
+ """
92
+
93
+ #: Array of record types, e.g. ATOM.
94
+ record_type: NDArray
95
+
96
+ #: Array of serial numbers.
97
+ atom_serial_number: NDArray
98
+
99
+ #: Array of atom names.
100
+ atom_name: NDArray
101
+
102
+ #: Array of x,y,z atom coordinates.
103
+ atom_coordinate: NDArray
104
+
105
+ #: Array of alternate location indices.
106
+ alternate_location_indicator: NDArray
107
+
108
+ #: Array of residue names.
109
+ residue_name: NDArray
110
+
111
+ #: Array of chain identifiers.
112
+ chain_identifier: NDArray
113
+
114
+ #: Array of residue ids.
115
+ residue_sequence_number: NDArray
116
+
117
+ #: Array of insertion information.
118
+ code_for_residue_insertion: NDArray
119
+
120
+ #: Array of occupancy factors.
121
+ occupancy: NDArray
122
+
123
+ #: Array of B-factors.
124
+ temperature_factor: NDArray
125
+
126
+ #: Array of segment identifiers.
127
+ segment_identifier: NDArray
128
+
129
+ #: Array of element symbols.
130
+ element_symbol: NDArray
131
+
132
+ #: Array of charges.
133
+ charge: NDArray
134
+
135
+ #: Metadata dictionary.
136
+ metadata: dict
137
+
138
def __post_init__(self, *args, **kwargs):
    """
    Coerce array fields to numpy arrays and populate header metadata.

    Every field annotated as ``NDArray`` is converted to a numpy array with
    at least one dimension. Afterwards the leading dimension of each array
    attribute is compared against the number of rows in
    ``atom_coordinate``.

    Raises
    ------
    ValueError
        If an array attribute does not match the number of atoms.
    """
    annotations = self.__annotations__
    for name in self.__dict__:
        current = getattr(self, name)
        if annotations.get(name, None) == NDArray:
            setattr(self, name, np.atleast_1d(np.array(current)))

    n_atoms = self.atom_coordinate.shape[0]
    for name in self.__dict__:
        current = getattr(self, name)
        # Non-array attributes such as the metadata dict are not checked.
        if isinstance(current, np.ndarray) and current.shape[0] != n_atoms:
            raise ValueError(
                f"Expected shape of {name}: {n_atoms}, got {current.shape[0]}."
            )

    self._elements = _Elements()
    self.metadata = self._populate_metadata(self.metadata)
165
+
166
+ def __getitem__(self, indices: List[int]) -> "Structure":
167
+ """
168
+ Get a Structure instance for specified indices.
169
+
170
+ Parameters
171
+ ----------
172
+ indices : Union[int, bool, NDArray]
173
+ The indices to get.
174
+
175
+ Returns
176
+ -------
177
+ Structure
178
+ The Structure instance for the given indices.
179
+ """
180
+ if type(indices) in (int, bool):
181
+ indices = (indices,)
182
+
183
+ indices = np.asarray(indices)
184
+ attributes = (
185
+ "record_type",
186
+ "atom_serial_number",
187
+ "atom_name",
188
+ "atom_coordinate",
189
+ "alternate_location_indicator",
190
+ "residue_name",
191
+ "chain_identifier",
192
+ "residue_sequence_number",
193
+ "code_for_residue_insertion",
194
+ "occupancy",
195
+ "temperature_factor",
196
+ "segment_identifier",
197
+ "element_symbol",
198
+ "charge",
199
+ )
200
+ kwargs = {attr: getattr(self, attr)[indices] for attr in attributes}
201
+ ret = self.__class__(**kwargs, metadata={})
202
+ return ret
203
+
204
+ def __repr__(self):
205
+ """
206
+ Return a string representation of the Structure.
207
+ """
208
+ unique_chains = "-".join(
209
+ [
210
+ ",".join([str(x) for x in entity])
211
+ for entity in self.metadata["unique_chains"]
212
+ ]
213
+ )
214
+ min_atom = np.min(self.atom_serial_number)
215
+ max_atom = np.max(self.atom_serial_number)
216
+ n_atom = self.atom_serial_number.size
217
+
218
+ min_residue = np.min(self.residue_sequence_number)
219
+ max_residue = np.max(self.residue_sequence_number)
220
+ n_residue = np.unique(self.residue_sequence_number).size
221
+
222
+ repr_str = (
223
+ f"Structure object at {id(self)}\n"
224
+ f"Unique Chains: {unique_chains}, "
225
+ f"Atom Range: {min_atom}-{max_atom} [N = {n_atom}], "
226
+ f"Residue Range: {min_residue}-{max_residue} [N = {n_residue}]"
227
+ )
228
+ return repr_str
229
+
230
+ def copy(self) -> "Structure":
231
+ """
232
+ Returns a copy of the Structure instance.
233
+
234
+ Returns
235
+ -------
236
+ :py:class:`Structure`
237
+ The copied Structure instance.
238
+
239
+ Examples
240
+ --------
241
+ >>> import numpy as np
242
+ >>> structure_copy = structure.copy()
243
+ >>> np.allclose(structure_copy.atom_coordinate, structure.atom_coordinate)
244
+ True
245
+ """
246
+ return deepcopy(self)
247
+
248
+ def _populate_metadata(self, metadata: Dict = {}) -> Dict:
249
+ """
250
+ Populate the metadata dictionary with the data from the Structure instance.
251
+
252
+ Parameters
253
+ ----------
254
+ metadata : dict, optional
255
+ The initial metadata dictionary, by default {}.
256
+
257
+ Returns
258
+ -------
259
+ dict
260
+ The populated metadata dictionary.
261
+ """
262
+ metadata["weight"] = np.sum(
263
+ [self._elements[atype].atomic_weight for atype in self.element_symbol]
264
+ )
265
+
266
+ label, idx, chain = np.unique(
267
+ self.chain_identifier, return_inverse=True, return_index=True
268
+ )
269
+ chain_weight = np.bincount(
270
+ chain,
271
+ [self._elements[atype].atomic_weight for atype in self.element_symbol],
272
+ )
273
+ labels = self.chain_identifier[idx]
274
+ metadata["chain_weight"] = {key: val for key, val in zip(labels, chain_weight)}
275
+
276
+ # Group non-unique chains in separate lists in metadata["unique_chains"]
277
+ metadata["unique_chains"], temp = [], {}
278
+ for chain_label in label:
279
+ index = len(metadata["unique_chains"])
280
+ chain_sequence = "".join(
281
+ [
282
+ str(y)
283
+ for y in self.element_symbol[
284
+ np.where(self.chain_identifier == chain_label)
285
+ ]
286
+ ]
287
+ )
288
+ if chain_sequence not in temp:
289
+ temp[chain_sequence] = index
290
+ metadata["unique_chains"].append([chain_label])
291
+ continue
292
+ idx = temp.get(chain_sequence)
293
+ metadata["unique_chains"][idx].append(chain_label)
294
+
295
+ filtered_data = [
296
+ (label, integer)
297
+ for label, integer in zip(
298
+ self.chain_identifier, self.residue_sequence_number
299
+ )
300
+ ]
301
+ filtered_data = sorted(filtered_data, key=lambda x: x[0])
302
+ metadata["chain_range"] = {}
303
+ for label, values in groupby(filtered_data, key=lambda x: x[0]):
304
+ values = [int(x[1]) for x in values]
305
+ metadata["chain_range"][label] = (min(values), max(values))
306
+
307
+ return metadata
308
+
309
+ @classmethod
310
+ def from_file(
311
+ cls,
312
+ filename: str,
313
+ keep_non_atom_records: bool = False,
314
+ filter_by_elements: set = None,
315
+ filter_by_residues: set = None,
316
+ ) -> "Structure":
317
+ """
318
+ Reads an atomic structure file and into a :py:class:`Structure` instance.
319
+
320
+ Parameters
321
+ ----------
322
+ filename : str
323
+ Input file. Supported extensions are:
324
+
325
+ +------+-------------------------------------------------------------+
326
+ | .pdb | Reads a PDB file |
327
+ +------+-------------------------------------------------------------+
328
+ | .cif | Reads an mmCIF file |
329
+ +------+-------------------------------------------------------------+
330
+ | .gro | Reads a Gromos87 formated file |
331
+ +------+-------------------------------------------------------------+
332
+ keep_non_atom_records : bool, optional
333
+ Wheter to keep residues that are not labelled ATOM.
334
+ filter_by_elements: set, optional
335
+ Which elements to keep. Default corresponds to all elements.
336
+ filter_by_residues: set, optional
337
+ Which residues to keep. Default corresponds to all residues.
338
+
339
+ Raises
340
+ ------
341
+ NotImplementedError
342
+ If the extension is not supported.
343
+
344
+ Returns
345
+ -------
346
+ :py:class:`Structure`
347
+ Structure instance representing the read in file.
348
+
349
+ Examples
350
+ --------
351
+ >>> from importlib_resources import files
352
+ >>> from tme import Structure
353
+ >>> fname = str(files("tests.data").joinpath("Structures/5khe.cif"))
354
+ >>> structure = Structure.from_file(filename=fname)
355
+ >>> structure
356
+ Unique Chains: A-B, Atom Range: 1-1564 [N = 1564], Residue Range: 142-239 [N = 1564]
357
+
358
+ We can include non ATOM entries and restrict the considered elements
359
+ and residues
360
+
361
+ >>> structure = Structure.from_file(
362
+ >>> filename=fname,
363
+ >>> keep_non_atom_records=True,
364
+ >>> filter_by_elements = {"C"},
365
+ >>> filter_by_residues = {"GLY"},
366
+ >>> )
367
+ >>> structure
368
+ Unique Chains: A,B, Atom Range: 96-1461 [N = 44], Residue Range: 154-228 [N = 44]
369
+ """
370
+ _, file_extension = splitext(basename(filename.lower()))
371
+ _formats = {
372
+ ".pdb": cls._load_pdb,
373
+ ".cif": cls._load_mmcif,
374
+ ".gro": cls._load_gro,
375
+ }
376
+ func = _formats.get(file_extension)
377
+ if func is None:
378
+ formats = ",".join([f"'{x}'" for x in _formats.keys()])
379
+ raise NotImplementedError(
380
+ f"Files with extension {file_extension} are not supported. "
381
+ f"Supported filetypes are {formats}."
382
+ )
383
+
384
+ data = func(cls, filename)
385
+ keep = np.ones(data["element_symbol"].size, dtype=bool)
386
+ if filter_by_elements:
387
+ keep = np.logical_and(
388
+ keep,
389
+ np.isin(data["element_symbol"], np.array(list(filter_by_elements))),
390
+ )
391
+ if filter_by_residues:
392
+ keep = np.logical_and(
393
+ keep, np.isin(data["residue_name"], np.array(list(filter_by_residues)))
394
+ )
395
+ if not keep_non_atom_records:
396
+ keep = np.logical_and(keep, data["record_type"] == "ATOM")
397
+
398
+ for key in data:
399
+ if keep.sum() == keep.size:
400
+ break
401
+
402
+ if key == "metadata":
403
+ continue
404
+ if isinstance(data[key], np.ndarray):
405
+ data[key] = data[key][keep]
406
+ else:
407
+ data[key] = [x for x, flag in zip(data[key], keep) if flag]
408
+
409
+ data["metadata"]["filepath"] = filename
410
+
411
+ return cls(**data)
412
+
413
+ @staticmethod
414
+ def _convert_dtypes(data: Dict[str, List], mapping: Dict):
415
+ """
416
+ Convert key values in data according to mapping.
417
+
418
+ Parameters
419
+ ----------
420
+ data : Dict
421
+ Mapping of keys to list of values
422
+ mapping : Dict
423
+ Mapping of key in return dict to (key, dtype) in data.
424
+
425
+ Returns
426
+ -------
427
+ dict
428
+ Key-value map using key-dtype pairs in mapping on data.
429
+ """
430
+ out = {}
431
+ max_len = max([len(t) for t in data.values() if hasattr(t, "__len__")])
432
+
433
+ missing_keys = set()
434
+ for out_key, (inner_key, dtype) in mapping.items():
435
+ default = "." if dtype is str else 0
436
+ if inner_key in data:
437
+ continue
438
+ missing_keys.add(inner_key)
439
+ out[out_key] = np.repeat(default, max_len).astype(dtype)
440
+
441
+ if len(missing_keys):
442
+ msg = ", ".join([f"'{x}'" for x in missing_keys])
443
+ warnings.warn(
444
+ f"Missing keys: ({msg}) in data - filling with default value."
445
+ )
446
+
447
+ for out_key, (inner_key, dtype) in mapping.items():
448
+ default = "." if dtype is str else 0
449
+
450
+ # Avoid modifying input dictionaries
451
+ if inner_key in missing_keys:
452
+ continue
453
+
454
+ out_data = data[inner_key]
455
+ if isinstance(data[inner_key][0], str):
456
+ out_data = [str(x).strip() for x in data[inner_key]]
457
+
458
+ out_data = np.asarray(out_data)
459
+ if dtype is int:
460
+ out_data = np.where(out_data == ".", "0", out_data)
461
+ elif dtype == "base-36":
462
+ dtype = int
463
+ base36_offset = int("A0000", 36) - 100000
464
+ out_data = np.where(
465
+ np.char.isdigit(out_data),
466
+ out_data,
467
+ np.vectorize(lambda x: int(x, 36) - base36_offset)(out_data),
468
+ )
469
+ try:
470
+ out[out_key] = np.asarray(out_data, dtype=dtype)
471
+ except ValueError:
472
+ print(
473
+ f"Converting {out_key} to {dtype} failed. Setting {out_key} to {default}."
474
+ )
475
+ out[out_key] = np.repeat(default, max_len).astype(dtype)
476
+ return out
477
+
478
def _load_mmcif(self, filename: str) -> Dict:
    """
    Parse a macromolecular Crystallographic Information File (mmCIF)
    into a dictionary of atom attributes.

    Parameters
    ----------
    filename : str
        The filename of the mmCIF to load.

    Returns
    -------
    dict
        Arrays keyed by the names of the PDB coordinate section, the
        ``atom_coordinate`` array as float32 and a ``metadata`` dictionary
        with experiment details found in the file. Attributes that fail
        conversion are replaced by a default value (0 for integers).
    """
    result = MMCIFParser(filename)
    atom_site = result["atom_site"]

    # The chain identifier is read from label_asym_id, not auth_asym_id.
    atom_site_mapping = {
        "record_type": ("group_PDB", str),
        "atom_serial_number": ("id", int),
        "atom_name": ("label_atom_id", str),
        "alternate_location_indicator": ("label_alt_id", str),
        "residue_name": ("label_comp_id", str),
        "chain_identifier": ("label_asym_id", str),
        "residue_sequence_number": ("label_seq_id", int),
        "code_for_residue_insertion": ("pdbx_PDB_ins_code", str),
        "occupancy": ("occupancy", float),
        "temperature_factor": ("B_iso_or_equiv", float),
        "segment_identifier": ("label_entity_id", str),
        "element_symbol": ("type_symbol", str),
        "charge": ("pdbx_formal_charge", str),
    }
    out = self._convert_dtypes(atom_site, atom_site_mapping)

    # Single-entry columns are repeated to the length of the longest column.
    number_entries = max(len(column) for column in out.values())
    for key, column in out.items():
        if column.size == 1:
            out[key] = np.repeat(column, number_entries)

    out["metadata"] = {}
    coordinate_columns = [
        atom_site["Cartn_x"],
        atom_site["Cartn_y"],
        atom_site["Cartn_z"],
    ]
    out["atom_coordinate"] = np.transpose(
        np.array(coordinate_columns, dtype=np.float32)
    )

    # Output key -> (category, field). Categories absent from the file are
    # skipped; fields absent from a present category become NaN.
    detail_mapping = {
        "resolution": ("em_3d_reconstruction", "resolution"),
        "resolution_method": ("em_3d_reconstruction", "resolution_method"),
        "method": ("exptl", "method"),
        "electron_source": ("em_imaging", "electron_source"),
        "illumination_mode": ("em_imaging", "illumination_mode"),
        "microscope_model": ("em_imaging", "microscope_model"),
    }
    for out_key, (category, field) in detail_mapping.items():
        if category in result:
            out["metadata"][out_key] = result[category].get(field, np.nan)

    return out
548
+
549
def _load_pdb(self, filename: str) -> Dict:
    """
    Parse a Protein Data Bank (PDB) file into a dictionary of atom
    attributes.

    Parameters
    ----------
    filename : str
        The filename of the PDB file to load.

    Returns
    -------
    dict
        Arrays keyed by the names of the PDB coordinate section, the
        ``atom_coordinate`` array as float32 and the parser's ``details``
        entry under ``metadata``. Attributes that fail conversion are
        replaced by a default value (0 for integers).
    """
    result = PDBParser(filename)

    # Parser keys equal the attribute names, so only the dtype varies.
    # Serial numbers use the "base-36" conversion of _convert_dtypes.
    field_types = {
        "record_type": str,
        "atom_serial_number": "base-36",
        "atom_name": str,
        "alternate_location_indicator": str,
        "residue_name": str,
        "chain_identifier": str,
        "residue_sequence_number": int,
        "code_for_residue_insertion": str,
        "occupancy": float,
        "temperature_factor": float,
        "segment_identifier": str,
        "element_symbol": str,
        "charge": str,
    }
    atom_site_mapping = {name: (name, kind) for name, kind in field_types.items()}

    out = self._convert_dtypes(result, atom_site_mapping)
    out["metadata"] = result["details"]
    out["atom_coordinate"] = np.array(result["atom_coordinate"], dtype=np.float32)
    return out
591
+
592
def _load_gro(self, filename):
    """
    Parse a GRO file into a dictionary of atom attributes.

    Both ``chain_identifier`` and ``segment_identifier`` are read from the
    parser's ``segment_identifier`` entry. Keys the parser does not provide
    are filled with default values by ``_convert_dtypes``.

    Parameters
    ----------
    filename : str
        The filename of the GRO file to load.

    Returns
    -------
    dict
        Arrays keyed by the names of the PDB coordinate section, the
        ``atom_coordinate`` array as float32 and a ``metadata`` dictionary
        with the ``title``, ``time`` and ``box_vectors`` parser entries.
    """
    result = GROParser(filename)

    atom_site_mapping = {
        "record_type": ("record_type", str),
        "atom_serial_number": ("atom_number", int),
        "atom_name": ("atom_name", str),
        "alternate_location_indicator": ("label_alt_id", str),
        "residue_name": ("residue_name", str),
        "chain_identifier": ("segment_identifier", str),
        "residue_sequence_number": ("residue_number", int),
        "code_for_residue_insertion": ("pdbx_PDB_ins_code", str),
        "occupancy": ("occupancy", float),
        "temperature_factor": ("B_iso_or_equiv", float),
        "segment_identifier": ("segment_identifier", str),
        "element_symbol": ("type_symbol", str),
        "charge": ("pdbx_formal_charge", str),
    }
    out = self._convert_dtypes(result, atom_site_mapping)

    n_segments = len(np.unique(out["segment_identifier"]))
    if n_segments > 1:
        warnings.warn(
            "Multiple GRO files detected - treating them as a single Structure. "
            "GRO file number is given by segment_identifier according to the "
            "input file. Note: You need to subset the Structure to operate on "
            "individual GRO files."
        )

    out["metadata"] = {
        key: result.get(key) for key in ("title", "time", "box_vectors")
    }
    out["atom_coordinate"] = np.asarray(result["atom_coordinate"], dtype=np.float32)
    return out
626
+
627
+ def to_file(self, filename: str) -> None:
628
+ """
629
+ Writes the :py:class:`Structure` instance to disk.
630
+
631
+ Parameters
632
+ ----------
633
+ filename : str
634
+ The name of the file to be created. Supported extensions are
635
+
636
+ +------+-------------------------------------------------------------+
637
+ | .pdb | Creates a PDB file |
638
+ +------+-------------------------------------------------------------+
639
+ | .cif | Creates an mmCIF file |
640
+ +------+-------------------------------------------------------------+
641
+ | .gro | Creates an Gromos87 file |
642
+ +------+-------------------------------------------------------------+
643
+ Raises
644
+ ------
645
+ NotImplementedError
646
+ If the extension is not supported.
647
+
648
+ Examples
649
+ --------
650
+ >>> from importlib_resources import files
651
+ >>> from tempfile import NamedTemporaryFile
652
+ >>> from tme import Structure
653
+ >>> fname = str(files("tests.data").joinpath("Structures/5khe.cif"))
654
+ >>> oname = NamedTemporaryFile().name
655
+ >>> structure = Structure.from_file(filename=fname)
656
+ >>> structure.to_file(f"{oname}.cif") # Writes an mmCIF file to disk
657
+ >>> structure.to_file(f"{oname}.pdb") # Writes a PDB file to disk
658
+
659
+ """
660
+ _, file_extension = splitext(basename(filename.lower()))
661
+ _formats = {
662
+ ".pdb": self._write_pdb,
663
+ ".cif": self._write_mmcif,
664
+ ".gro": self._write_gro,
665
+ }
666
+ func = _formats.get(file_extension)
667
+ if func is None:
668
+ formats = ",".join([f"'{x}'" for x in _formats.keys()])
669
+ raise NotImplementedError(
670
+ f"Files with extension {file_extension} are not supported. "
671
+ f"Supported filetypes are {formats}."
672
+ )
673
+
674
+ if np.any(np.vectorize(len)(self.chain_identifier) > 2):
675
+ warnings.warn("Chain identifiers longer than one will be shortened.")
676
+
677
+ if self.atom_coordinate.shape[0] > 10**5 and func == self._write_pdb:
678
+ warnings.warn(
679
+ "The structure contains more than 100,000 atoms. Consider using mmcif."
680
+ )
681
+
682
+ with open(filename, mode="w", encoding="utf-8") as ofile:
683
+ ofile.write(func())
684
+
685
+ def _write_pdb(self) -> str:
686
+ """
687
+ Returns a PDB string representation of the structure instance.
688
+
689
+ Returns
690
+ -------
691
+ str
692
+ String containing PDB file coordine lines.
693
+ """
694
+ data_out = []
695
+ for index in range(self.atom_coordinate.shape[0]):
696
+ x, y, z = self.atom_coordinate[index, :]
697
+ line = list(" " * 80)
698
+ line[0:6] = f"{self.record_type[index]:<6}"
699
+ line[6:11] = f"{self.atom_serial_number[index]:>5}"
700
+ line[12:16] = f"{self.atom_name[index]:<4}"
701
+ line[16] = f"{self.alternate_location_indicator[index]:<1}"
702
+ line[17:20] = f"{self.residue_name[index]:<3}"
703
+ line[21] = f"{self.chain_identifier[index][0]:<1}"
704
+ line[22:26] = f"{self.residue_sequence_number[index]:>4}"
705
+ line[26] = f"{self.code_for_residue_insertion[index]:<1}"
706
+ line[30:38] = f"{x:>8.3f}"
707
+ line[38:46] = f"{y:>8.3f}"
708
+ line[46:54] = f"{z:>8.3f}"
709
+ line[54:60] = f"{self.occupancy[index]:>6.2f}"
710
+ line[60:66] = f"{self.temperature_factor[index]:>6.2f}"
711
+ line[72:76] = f"{self.segment_identifier[index]:>4}"
712
+ line[76:78] = f"{self.element_symbol[index]:<2}"
713
+ line[78:80] = f"{self.charge[index]:>2}"
714
+ data_out.append("".join(line))
715
+ data_out.append("END")
716
+ data_out = "\n".join(data_out)
717
+ return data_out
718
+
719
def _write_mmcif(self) -> str:
    """
    Render the instance as mmCIF text.

    An ``atom_site`` table is built from the instance attributes. If the
    file named by ``metadata["filepath"]`` can be parsed as mmCIF, its
    categories are written instead: ``atom_site`` rows are picked by
    ``atom_serial_number - 1`` and the coordinates are replaced by those of
    this instance. Any failure in that step falls back to the generated
    table.

    Returns
    -------
    str
        String containing MMCIF file coordinate lines.
    """
    model_num, entity_id = 1, 1
    column_names = (
        "group_PDB",
        "id",
        "type_symbol",
        "label_atom_id",
        "label_alt_id",
        "label_comp_id",
        "label_asym_id",
        "label_entity_id",
        "label_seq_id",
        "pdbx_PDB_ins_code",
        "Cartn_x",
        "Cartn_y",
        "Cartn_z",
        "occupancy",
        "B_iso_or_equiv",
        "pdbx_formal_charge",
        "auth_seq_id",
        "auth_comp_id",
        "auth_asym_id",
        "auth_atom_id",
        "pdbx_PDB_model_num",
    )
    data = {name: [] for name in column_names}

    for row in range(self.atom_coordinate.shape[0]):
        x, y, z = self.atom_coordinate[row, :]
        # Values in the same order as column_names. Only the first character
        # of the chain identifier is written.
        record = (
            self.record_type[row],
            str(self.atom_serial_number[row]),
            self.element_symbol[row],
            self.atom_name[row],
            self.alternate_location_indicator[row],
            self.residue_name[row],
            self.chain_identifier[row][0],
            str(entity_id),
            str(self.residue_sequence_number[row]),
            self.code_for_residue_insertion[row],
            f"{x:.3f}",
            f"{y:.3f}",
            f"{z:.3f}",
            f"{self.occupancy[row]:.2f}",
            f"{self.temperature_factor[row]:.2f}",
            self.charge[row],
            str(self.residue_sequence_number[row]),
            self.residue_name[row],
            self.chain_identifier[row][0],
            self.atom_name[row],
            str(model_num),
        )
        for name, value in zip(column_names, record):
            data[name].append(value)

    output_data = {"atom_site": data}
    original_file = self.metadata.get("filepath", "")
    try:
        # Best effort: reuse every category of the source file when it is
        # readable as mmCIF. NOTE(review): row selection presumes serial
        # numbers are 1-based positions in the source table - confirm.
        parsed = dict(MMCIFParser(original_file).items())
        selected_rows = self.atom_serial_number - 1
        parsed["atom_site"] = {
            name: [column[i] for i in selected_rows]
            for name, column in parsed["atom_site"].items()
        }
        for axis in ("Cartn_x", "Cartn_y", "Cartn_z"):
            parsed["atom_site"][axis] = data[axis]
        output_data = parsed
    except Exception:
        pass

    chunks = []
    for category, subdict in output_data.items():
        if not len(subdict):
            continue

        chunks.append("#\n")
        # A category is written as a loop when its first entry is a list.
        first_value = subdict[next(iter(subdict))]
        if not isinstance(first_value, list):
            chunks.extend(f"_{category}.{k}\t{subdict[k]}\n" for k in subdict)
            continue

        chunks.append("loop_\n")
        chunks.extend(f"_{category}.{k}\n" for k in subdict)

        formatted = {
            name: [_format_string(entry) for entry in column]
            for name, column in subdict.items()
        }
        # Pad every column to its widest entry plus one separating space.
        padded = {}
        for name, column in formatted.items():
            width = len(max(column, key=len, default=""))
            padded[name] = [entry.ljust(width + 1) for entry in column]

        table_rows = [
            "".join(str(cell) for cell in cells) for cells in zip(*padded.values())
        ]
        chunks.append("\n".join(table_rows) + "\n")

    return "".join(chunks)
825
+
826
+ def _write_gro(self) -> str:
827
+ """
828
+ Generate a GRO format string representation of the structure.
829
+
830
+ Returns
831
+ -------
832
+ str
833
+ String representation of the structure in GRO format.
834
+ """
835
+ ret = ""
836
+ gro_files = np.unique(self.segment_identifier)
837
+ for index, gro_file in enumerate(gro_files):
838
+ subset = self[self.segment_identifier == gro_file]
839
+
840
+ title = self.metadata.get("title", "Missing title")
841
+ box_vectors = self.metadata.get("box_vectors")
842
+ try:
843
+ title = title[index]
844
+ box_vectors = box_vectors[index]
845
+ except Exception:
846
+ pass
847
+
848
+ if box_vectors is None:
849
+ box_vectors = [0.0, 0.0, 0.0]
850
+
851
+ num_atoms = subset.atom_coordinate.shape[0]
852
+ lines = [title, f"{num_atoms}"]
853
+ for i in range(num_atoms):
854
+ res_num = subset.residue_sequence_number[i]
855
+ res_name = subset.residue_name[i]
856
+ atom_name = subset.atom_name[i]
857
+ atom_num = subset.atom_serial_number[i]
858
+
859
+ x, y, z = subset.atom_coordinate[i]
860
+ coord = f"{atom_num % 100000:5d}{x:8.3f}{y:8.3f}{z:8.3f}"
861
+ line = f"{res_num % 100000:5d}{res_name:5s}{atom_name:5s}{coord}"
862
+
863
+ if "velocity" in subset.metadata:
864
+ vx, vy, vz = subset.metadata["velocity"][i]
865
+ line += f"{vx:8.4f}{vy:8.4f}{vz:8.4f}"
866
+
867
+ lines.append(line)
868
+
869
+ lines.append(" ".join(f"{v:.5f}" for v in box_vectors))
870
+ ret += "\n".join(lines) + "\n"
871
+ return ret
872
+
873
def subset_by_chain(self, chain: str = None) -> "Structure":
    """
    Select the atoms that belong to the requested chain(s).

    Parameters
    ----------
    chain : str, optional
        Chain identifier, or several identifiers joined by commas,
        e.g. 'A,B,CE'. None (default) selects every chain.

    Returns
    -------
    :py:class:`Structure`
        The result of indexing the instance with a boolean mask that is
        True for atoms whose chain identifier was requested.

    Examples
    --------
    >>> from importlib_resources import files
    >>> from tme import Structure
    >>> fname = str(files("tests.data").joinpath("Structures/5khe.cif"))
    >>> structure = Structure.from_file(filename=fname)
    >>> structure.subset_by_chain(chain="A")    # Keep A
    >>> structure.subset_by_chain(chain="A,B")  # Keep A and B
    >>> structure.subset_by_chain(chain="B,C")  # Keep B, C does not exist
    """
    if chain is None:
        requested = np.unique(self.chain_identifier)
    else:
        requested = chain.split(",")

    # NOTE(review): an empty selection is not checked here; any error for it
    # would have to come from the indexing operation - confirm.
    return self[np.isin(self.chain_identifier, requested)]
908
+
909
def subset_by_range(
    self,
    start: int,
    stop: int,
    chain: str = None,
) -> "Structure":
    """
    Select atoms whose residue sequence number lies in ``[start, stop]``.

    The chain selection is applied first via :py:meth:`subset_by_chain`;
    the residue filter is then applied to that subset. Both bounds are
    inclusive.

    Parameters
    ----------
    start : int
        First residue sequence number to keep.
    stop : int
        Last residue sequence number to keep.
    chain : str, optional
        Chain identifier, or several identifiers joined by commas,
        e.g. 'A,B,CE'. None (default) selects every chain.

    Returns
    -------
    :py:class:`Structure`
        Subset of the chain selection restricted to the residue range.

    Examples
    --------
    >>> from importlib_resources import files
    >>> from tme import Structure
    >>> fname = str(files("tests.data").joinpath("Structures/5khe.cif"))
    >>> structure = Structure.from_file(filename=fname)
    >>> structure.subset_by_range(chain="A",start=150,stop=180)
    """
    chain_subset = self.subset_by_chain(chain=chain)
    residues = chain_subset.residue_sequence_number
    in_range = np.logical_and(residues >= start, residues <= stop)
    return chain_subset[in_range]
952
+
953
def center_of_mass(self, weight_type: str = "atomic_weight") -> NDArray:
    """
    Calculate the center of mass of the structure.

    Parameters
    ----------
    weight_type : str, optional
        The weight given to each atom. This can either be 'atomic_weight',
        'atomic_number', or 'equal'. Defaults to 'atomic_weight'.

    Returns
    -------
    NDArray
        The weighted mean of the atom coordinates.

    Raises
    ------
    NotImplementedError
        If ``weight_type`` is not one of the supported options.

    Examples
    --------
    >>> from importlib_resources import files
    >>> from tme import Structure
    >>> fname = str(files("tests.data").joinpath("Structures/5khe.cif"))
    >>> structure = Structure.from_file(filename=fname)
    >>> structure.center_of_mass()
    array([-0.89391639, 29.94908928, -2.64736741])
    """
    atoms = self.element_symbol
    if weight_type in ("atomic_weight", "atomic_number"):
        # Both options name an attribute of the per-element lookup entries.
        weights = [getattr(self._elements[atom], weight_type) for atom in atoms]
    elif weight_type == "equal":
        weights = np.ones(len(atoms))
    else:
        raise NotImplementedError(
            "weight_type can be 'atomic_weight', 'atomic_number' or 'equal'."
        )
    return np.dot(self.atom_coordinate.T, weights) / np.sum(weights)
991
+
992
def rigid_transform(
    self,
    rotation_matrix: NDArray = None,
    translation: NDArray = None,
    use_geometric_center: bool = False,
) -> "Structure":
    """
    Apply a rigid transform to the atom coordinates and return a copy.

    Parameters
    ----------
    rotation_matrix : NDArray, optional
        The rotation matrix to apply to the coordinates, defaults to identity.
    translation : NDArray, optional
        The vector to translate the coordinates by, defaults to 0.
    use_geometric_center : bool, optional
        Forwarded unchanged to the module-level ``rigid_transform`` helper.

    Returns
    -------
    Structure
        A copy of the instance holding the transformed coordinates. The
        instance itself is not modified.

    Examples
    --------
    >>> from importlib_resources import files
    >>> from tme import Structure
    >>> from tme.matching_utils import get_rotation_matrices
    >>> fname = str(files("tests.data").joinpath("Structures/5khe.cif"))
    >>> structure = Structure.from_file(filename=fname)
    >>> structure.rigid_transform(
    ...     rotation_matrix = get_rotation_matrices(60)[2],
    ...     translation = (0, 1, -5)
    ... )
    """
    n_dims = self.atom_coordinate.shape[1]
    if rotation_matrix is None:
        rotation_matrix = np.eye(n_dims)
    if translation is None:
        translation = np.zeros(n_dims)

    # The helper works on (dimension, n_atoms) arrays and writes into ``out``.
    transformed = np.empty_like(self.atom_coordinate.T)
    # Inside the class body this name resolves to the module-level function,
    # not to this method.
    rigid_transform(
        coordinates=self.atom_coordinate.T,
        rotation_matrix=rotation_matrix,
        translation=translation,
        out=transformed,
        use_geometric_center=use_geometric_center,
    )

    result = self.copy()
    result.atom_coordinate = transformed.T.copy()
    return result
1043
+
1044
def centered(self) -> Tuple["Structure", NDArray]:
    """
    Shifts the structure analogous to :py:meth:`tme.density.Density.centered`.

    The translation is half of the value returned by
    ``minimum_enclosing_box`` minus the center of mass, applied without
    rotation.

    Returns
    -------
    Structure
        A translated copy of the class instance.
    NDArray
        The coordinate translation that was applied.

    See Also
    --------
    :py:meth:`tme.density.Density.centered`

    Examples
    --------
    >>> from importlib_resources import files
    >>> from tme import Structure
    >>> fname = str(files("tests.data").joinpath("Structures/5khe.cif"))
    >>> structure = Structure.from_file(filename=fname)
    >>> centered_structure, translation = structure.centered()
    >>> translation
    array([34.89391639, 4.05091072, 36.64736741])
    """
    mass_center = self.center_of_mass()
    box = minimum_enclosing_box(coordinates=self.atom_coordinate.T)
    shift = np.subtract(np.divide(box, 2), mass_center)

    shifted = self.rigid_transform(
        translation=shift, rotation_matrix=np.eye(shift.size)
    )
    return shifted, shift
1079
+
1080
+ def _coordinate_to_position(
1081
+ self,
1082
+ shape: Tuple[int],
1083
+ sampling_rate: Tuple[float],
1084
+ origin: Tuple[float],
1085
+ ) -> (NDArray, Tuple[str], Tuple[int], float, Tuple[float]):
1086
+ """
1087
+ Converts coordinates to positions.
1088
+
1089
+ Parameters
1090
+ ----------
1091
+ shape : Tuple[int,]
1092
+ The desired shape of the output array.
1093
+ sampling_rate : float
1094
+ The sampling rate of the output array in unit of self.atom_coordinate.
1095
+ origin : Tuple[float,]
1096
+ The origin of the coordinate system.
1097
+
1098
+ Returns
1099
+ -------
1100
+ Tuple[NDArray, List[str], Tuple[int, ], float, Tuple[float,]]
1101
+ Returns positions, atom_types, shape, sampling_rate, and origin.
1102
+ """
1103
+ coordinates = self.atom_coordinate.copy()
1104
+ atom_types = self.element_symbol.copy()
1105
+
1106
+ coordinates = coordinates
1107
+ sampling_rate = 1 if sampling_rate is None else sampling_rate
1108
+ adjust_origin = origin is not None and shape is None
1109
+ origin = coordinates.min(axis=0) if origin is None else origin
1110
+ positions = (coordinates - origin) / sampling_rate
1111
+ positions = np.rint(positions).astype(int)
1112
+
1113
+ if adjust_origin:
1114
+ left_shift = positions.min(axis=0)
1115
+ positions -= left_shift
1116
+ shape = positions.max(axis=0) + 1
1117
+ origin = origin + np.multiply(left_shift, sampling_rate)
1118
+
1119
+ if shape is None:
1120
+ shape = positions.max(axis=0) + 1
1121
+
1122
+ valid_positions = np.sum(
1123
+ np.logical_and(positions < shape, positions >= 0), axis=1
1124
+ )
1125
+
1126
+ positions = positions[valid_positions == positions.shape[1], :]
1127
+ atom_types = atom_types[valid_positions == positions.shape[1]]
1128
+
1129
+ self.metadata["nAtoms_outOfBound"] = 0
1130
+ if positions.shape[0] != coordinates.shape[0]:
1131
+ out_of_bounds = coordinates.shape[0] - positions.shape[0]
1132
+ print(f"{out_of_bounds}/{coordinates.shape[0]} atoms were out of bounds.")
1133
+ self.metadata["nAtoms_outOfBound"] = out_of_bounds
1134
+
1135
+ return positions, atom_types, shape, sampling_rate, origin
1136
+
1137
+ def _position_to_vdw_sphere(
1138
+ self,
1139
+ positions: Tuple[float],
1140
+ atoms: Tuple[str],
1141
+ sampling_rate: Tuple[float],
1142
+ volume: NDArray,
1143
+ ) -> None:
1144
+ """
1145
+ Updates a volume with van der Waals spheres.
1146
+
1147
+ Parameters
1148
+ ----------
1149
+ positions : Tuple[float, float, float]
1150
+ The positions of the atoms.
1151
+ atoms : Tuple[str]
1152
+ The types of the atoms.
1153
+ sampling_rate : float
1154
+ The desired sampling rate in unit of self.atom_coordinate of the
1155
+ output array.
1156
+ volume : NDArray
1157
+ The volume to update.
1158
+ """
1159
+ index_dict, vdw_rad, shape = {}, {}, volume.shape
1160
+ for atom_index, atom_position in enumerate(positions):
1161
+ atom_type = atoms[atom_index]
1162
+ if atom_type not in index_dict.keys():
1163
+ atom_vdwr = np.ceil(
1164
+ np.divide(self._elements[atom_type].vdwr, (sampling_rate * 100))
1165
+ ).astype(int)
1166
+
1167
+ vdw_rad[atom_type] = atom_vdwr
1168
+ atom_slice = tuple(slice(-k, k + 1) for k in atom_vdwr)
1169
+ distances = np.linalg.norm(
1170
+ np.divide(
1171
+ np.mgrid[atom_slice],
1172
+ atom_vdwr.reshape((-1,) + (1,) * volume.ndim),
1173
+ ),
1174
+ axis=0,
1175
+ )
1176
+ index_dict[atom_type] = (distances <= 1).astype(volume.dtype)
1177
+
1178
+ footprint = index_dict[atom_type]
1179
+ start = np.maximum(np.subtract(atom_position, vdw_rad[atom_type]), 0)
1180
+ stop = np.minimum(np.add(atom_position, vdw_rad[atom_type]) + 1, shape)
1181
+ volume_slice = tuple(slice(*coord) for coord in zip(start, stop))
1182
+
1183
+ start_index = np.maximum(-np.subtract(atom_position, vdw_rad[atom_type]), 0)
1184
+ stop_index = np.add(
1185
+ footprint.shape,
1186
+ np.minimum(
1187
+ np.subtract(shape, np.add(atom_position, vdw_rad[atom_type]) + 1), 0
1188
+ ),
1189
+ )
1190
+ index_slice = tuple(slice(*coord) for coord in zip(start_index, stop_index))
1191
+ volume[volume_slice] += footprint[index_slice]
1192
+
1193
def _position_to_scattering_factors(
    self,
    positions: NDArray,
    atoms: NDArray,
    sampling_rate: NDArray,
    volume: NDArray,
    lowpass_filter: bool = True,
    downsampling_factor: float = 1.35,
    source: str = "peng1995",
) -> None:
    """
    Add radial scattering profiles to ``volume`` for every atom.

    One profile is requested from ``atom_profile`` per element and reused.
    It is evaluated at the distance between the atom position and each grid
    point of a box around it whose half-width is
    ``vdwr / (sampling_rate * 100)``.

    Parameters
    ----------
    positions : NDArray
        The grid positions of the atoms.
    atoms : NDArray
        Element symbols.
    sampling_rate : float
        Sampling rate that was used to convert coordinates to positions.
    volume : NDArray
        The volume to update in place.
    lowpass_filter : bool
        Forwarded to ``atom_profile`` as ``lfilter``.
    downsampling_factor : float
        Forwarded to ``atom_profile`` as ``M``.
    source : str
        Forwarded to ``atom_profile`` as ``method``.

    Reference
    ---------
    https://github.com/I2PC/xmipp.
    """
    profiles, grid_shape = dict(), volume.shape
    for atom_index, point in enumerate(positions):
        atom = atoms[atom_index]
        if atom not in profiles:
            profiles[atom] = atom_profile(
                atom=atom,
                M=downsampling_factor,
                method=source,
                lfilter=lowpass_filter,
            )

        atomic_radius = np.divide(self._elements[atom].vdwr, sampling_rate * 100)
        starts = np.maximum(np.ceil(point - atomic_radius), 0).astype(int)
        stops = np.minimum(np.floor(point + atomic_radius), grid_shape).astype(int)

        grid_index = np.meshgrid(
            *[range(lower, upper) for lower, upper in zip(starts, stops)]
        )
        # The "aijk" subscripts sum the squared offsets over the leading
        # (per-axis) dimension and fix the grid at three dimensions.
        squared_offsets = np.array(
            [(grid_index[axis] - point[axis]) ** 2 for axis in range(len(point))]
        )
        distances = np.sqrt(
            np.einsum("aijk->ijk", squared_offsets, dtype=np.float64)
        )
        if not len(distances):
            # Empty box: fall back to the voxel at the atom position itself.
            grid_index, distances = point, 0

        np.add.at(volume, tuple(grid_index), profiles[atom](distances))
1260
+
1261
@staticmethod
def _position_to_molmap(
    positions: NDArray,
    weights: Tuple[float],
    resolution: float = 4,
    sigma_factor: float = 1 / (np.pi * np.sqrt(2)),
    cutoff_value: float = 4.0,
    sampling_rate: float = None,
) -> Tuple[NDArray, NDArray]:
    """
    Simulates electron densities analogous to Chimera's molmap function [1]_.

    Parameters
    ----------
    positions : NDArray
        Array containing atomic positions in x,y,z format (n,d).
    weights : tuple of float
        The weights to use for the entries in positions.
    resolution : float, optional
        The product of resolution and sigma_factor gives the sigma used to
        compute the discretized Gaussian.
    sigma_factor : float, optional
        The factor used with resolution to compute sigma. Default is 1 / (π√2).
    cutoff_value : float, optional
        The cutoff value for the Gaussian kernel. Default is 4.0.
    sampling_rate : float, optional
        Sampling rate along each dimension. One third of resolution by default.

    References
    ----------
    ..[1] https://www.cgl.ucsf.edu/chimera/docs/UsersGuide/midas/molmap.html

    Returns
    -------
    NDArray
        A numpy array containing the simulated electron densities.
    NDArray
        The origin of the returned array in units of ``positions``.
    """
    if sampling_rate is None:
        sampling_rate = resolution / 3

    pad = int(3 * resolution)
    sigma = sigma_factor * resolution
    sigma_grid = sigma / sampling_rate

    # Limit padding to numerically stable values
    smax = np.max(sigma_grid)
    arr = np.arange(0, pad)
    gaussian = (
        np.exp(-0.5 * (arr / smax) ** 2)
        * np.power(2 * np.pi, -1.5)
        * np.power(sigma, -3.0)
    )
    # Shrink the padding only if some offsets are still significant. Taking
    # the maximum of an empty selection would raise, so in that case the
    # initial padding is kept.
    significant = arr[gaussian > 1e-8]
    if significant.size != 0:
        pad = int(np.max(significant)) + 1

    origin = positions.min(axis=0) - pad * sampling_rate
    positions = np.rint(np.divide((positions - origin), sampling_rate)).astype(int)

    shape = positions.max(axis=0).astype(int) + pad + 1
    out = np.zeros(shape, dtype=np.float32)
    np.add.at(out, tuple(positions.T), weights)

    out = Preprocessor().gaussian_filter(
        template=out, sigma=sigma_grid, cutoff_value=cutoff_value
    )
    return out, origin
1329
+
1330
+ def _get_atom_weights(
1331
+ self, atoms: Tuple[str] = None, weight_type: str = "atomic_weight"
1332
+ ) -> Tuple[float]:
1333
+ """
1334
+ Returns weights of individual atoms according to a specified weight type.
1335
+
1336
+ Parameters
1337
+ ----------
1338
+ atoms : Tuple of strings, optional
1339
+ The atoms to get the weights for. If None, weights for all atoms
1340
+ are used. Default is None.
1341
+
1342
+ weight_type : str, optional
1343
+ The type of weights to return. This can either be 'atomic_weight',
1344
+ 'atomic_number', or 'van_der_waals_radius'. Default is 'atomic_weight'.
1345
+
1346
+ Returns
1347
+ -------
1348
+ List[float]
1349
+ A list containing the weights of the atoms.
1350
+ """
1351
+ atoms = self.element_symbol if atoms is None else atoms
1352
+ match weight_type:
1353
+ case "atomic_weight":
1354
+ weight = [self._elements[atom].atomic_weight for atom in atoms]
1355
+ case "atomic_number":
1356
+ weight = [self._elements[atom].atomic_number for atom in atoms]
1357
+ case _:
1358
+ raise NotImplementedError(
1359
+ "weight_type can either be 'atomic_weight' or 'atomic_number'"
1360
+ )
1361
+ return weight
1362
+
1363
def to_volume(
    self,
    shape: Tuple[int] = None,
    sampling_rate: Tuple[float] = None,
    origin: Tuple[float] = None,
    chain: str = None,
    weight_type: str = "atomic_weight",
    weight_type_args: Dict = None,
) -> Tuple[NDArray, NDArray, NDArray]:
    """
    Maps class instance to a volume.

    Parameters
    ----------
    shape : tuple of ints, optional
        Output array shape in (x,y,z) form.
    sampling_rate : tuple of float, optional
        Sampling rate of the output array in units of
        :py:attr:`Structure.atom_coordinate`. A single value is repeated
        for every axis. Defaults to one along each axis.
    origin : tuple of floats, optional
        Origin of the coordinate system in (x,y,z) form.
    chain : str, optional
        Chains to be included. Either single or comma separated string of chains.
        Defaults to None which returns all chains.
    weight_type : str, optional
        Weight given to individual atoms. Supported weights are:

        +----------------------------+---------------------------------------+
        | atomic_weight              | Using element weight point mass       |
        +----------------------------+---------------------------------------+
        | atomic_number              | Using atomic number point mass        |
        +----------------------------+---------------------------------------+
        | gaussian                   | Using element weighted Gaussian mass  |
        +----------------------------+---------------------------------------+
        | van_der_waals_radius       | Using binary van der waal spheres     |
        +----------------------------+---------------------------------------+
        | scattering_factors         | Using experimental scattering factors |
        +----------------------------+---------------------------------------+
        | lowpass_scattering_factors | Lowpass filtered scattering_factors   |
        +----------------------------+---------------------------------------+
    weight_type_args : dict, optional
        Additional keyword arguments forwarded to the routine that
        implements the chosen weight_type. `gaussian` accepts
        ``resolution``, ``sigma_factor`` and ``cutoff_value``. The
        scattering factor types accept ``source`` and
        ``downsampling_factor``. Defaults to no additional arguments.

    Returns
    -------
    Tuple[NDArray, NDArray, NDArray]
        Volume, origin and sampling_rate.

    Raises
    ------
    NotImplementedError
        If ``weight_type`` is not supported.
    ValueError
        If ``sampling_rate`` has neither one element nor one per axis.

    Examples
    --------
    >>> from importlib_resources import files
    >>> from tme import Structure
    >>> fname = str(files("tests.data").joinpath("Structures/5khe.cif"))
    >>> structure = Structure.from_file(filename=fname)
    >>> vol, origin, sampling = structure.to_volume()
    >>> vol.shape, origin, sampling
    ((59, 35, 53), array([-30.71, 12.42, -27.15]), array([1., 1., 1.]))
    >>> vol, origin, sampling = structure.to_volume(sampling_rate=(2.2,1,3))
    >>> vol.shape, origin, sampling
    ((27, 35, 18), array([-30.71, 12.42, -27.15]), array([2.2, 1. , 3. ]))

    ``sampling_rate`` and ``origin`` can be set to ensure correct alignment
    with corresponding density maps such as the ones at EMDB. Analogous to
    :py:meth:`Structure.subset_by_chain` only parts of the structure can be
    mapped onto grids using a variety of weighting schemes

    >>> structure.to_volume(weight_type="van_der_waals_radius")
    >>> structure.to_volume(
    ...     weight_type="lowpass_scattering_factors",
    ...     weight_type_args={"source" : "dt1969", "downsampling_factor" : 1.35},
    ... )
    """
    # A tuple keeps the order of the error message deterministic.
    _weight_types = (
        "gaussian",
        "atomic_weight",
        "atomic_number",
        "van_der_waals_radius",
        "scattering_factors",
        "lowpass_scattering_factors",
    )
    if weight_type not in _weight_types:
        _weight_string = ",".join([f"'{x}'" for x in _weight_types])
        raise NotImplementedError(f"weight_type needs to be in {_weight_string}")

    if weight_type_args is None:
        weight_type_args = {}

    n_dims = self.atom_coordinate.shape[1]
    if sampling_rate is None:
        sampling_rate = np.ones(n_dims)
    sampling_rate = np.array(sampling_rate)
    if sampling_rate.size == 1:
        sampling_rate = np.repeat(sampling_rate, n_dims)
    elif sampling_rate.size != n_dims:
        raise ValueError(
            "sampling_rate should either be single value or array with "
            f"size {n_dims}."
        )

    temp = self.subset_by_chain(chain=chain)
    positions, atoms, _shape, sampling_rate, origin = temp._coordinate_to_position(
        shape=shape, sampling_rate=sampling_rate, origin=origin
    )
    volume = np.zeros(_shape, dtype=np.float32)
    if weight_type in ("atomic_weight", "atomic_number"):
        weights = temp._get_atom_weights(atoms=atoms, weight_type=weight_type)
        np.add.at(volume, tuple(positions.T), weights)
    elif weight_type == "van_der_waals_radius":
        self._position_to_vdw_sphere(positions, atoms, sampling_rate, volume)
    elif weight_type in ("scattering_factors", "lowpass_scattering_factors"):
        self._position_to_scattering_factors(
            positions,
            atoms,
            sampling_rate,
            volume,
            lowpass_filter=weight_type == "lowpass_scattering_factors",
            **weight_type_args,
        )
    elif weight_type == "gaussian":
        # The Gaussian map is built from all coordinates of the subset, so
        # the weights must cover all of its atoms rather than only those
        # that passed the bounds filter above.
        volume, origin = self._position_to_molmap(
            positions=temp.atom_coordinate,
            weights=temp._get_atom_weights(weight_type="atomic_number"),
            sampling_rate=sampling_rate,
            **weight_type_args,
        )

    self.metadata.update(temp.metadata)
    return volume, origin, sampling_rate
1498
+
1499
@classmethod
def compare_structures(
    cls,
    structure1: "Structure",
    structure2: "Structure",
    origin: NDArray = None,
    sampling_rate: float = None,
    weighted: bool = False,
) -> float:
    """
    Compute root mean square deviation (RMSD) between two structures with the
    same number of atoms.

    Parameters
    ----------
    structure1, structure2 : :py:class:`Structure`
        Structure instances to compare.
    origin : tuple of floats, optional
        Coordinate system origin. For computing RMSD on discretized grids.
        Defaults to zero along each axis.
    sampling_rate : tuple of floats, optional
        Sampling rate in units of :py:attr:`atom_coordinate`.
        For computing RMSD on discretized grids.
    weighted : bool, optional
        Whether atoms should be weighted according to their atomic weight.

    Returns
    -------
    float
        Root Mean Square Deviation between input structures.

    Raises
    ------
    ValueError
        If the structures differ in coordinate shape or number of atoms.

    Examples
    --------
    >>> from importlib_resources import files
    >>> from tme.matching_utils import get_rotation_matrices
    >>> from tme import Structure
    >>> fname = str(files("tests.data").joinpath("Structures/5khe.cif"))
    >>> structure = Structure.from_file(filename=fname)
    >>> transformed = structure.rigid_transform(
    ...     rotation_matrix = get_rotation_matrices(60)[2],
    ...     translation = (0, 1, -5)
    ... )
    >>> Structure.compare_structures(structure, transformed)
    31.35238
    >>> Structure.compare_structures(structure, structure)
    0.0
    """
    coordinates1 = np.asarray(structure1.atom_coordinate)
    coordinates2 = np.asarray(structure2.atom_coordinate)

    # Shapes are compared exactly. A tolerance based comparison would accept
    # atom counts that differ by one once they reach about 1e5.
    if coordinates1.shape != coordinates2.shape:
        raise ValueError(
            "Input structures need to have the same number of coordinates."
        )

    if origin is None:
        origin = np.zeros(coordinates1.shape[1])

    if sampling_rate is not None:
        coordinates1 = np.rint(
            np.divide(np.subtract(coordinates1, origin), sampling_rate)
        ).astype(int)
        coordinates2 = np.rint(
            np.divide(np.subtract(coordinates2, origin), sampling_rate)
        ).astype(int)

    if weighted:
        weights1 = np.array(
            structure1._get_atom_weights(atoms=structure1.element_symbol)
        )
        weights2 = np.array(
            structure2._get_atom_weights(atoms=structure2.element_symbol)
        )
    else:
        # One unit weight per atom, so the check below is meaningful.
        weights1 = np.ones(coordinates1.shape[0])
        weights2 = np.ones(coordinates2.shape[0])

    if weights1.shape != weights2.shape:
        raise ValueError("Input structures need to have the same number of atoms.")

    squared_diff = np.sum(np.square(coordinates1 - coordinates2), axis=1)
    # NOTE(review): the weighted value is a plain mean of the weighted squared
    # deviations and is not normalised by the sum of the weights.
    weighted_squared_diff = squared_diff * ((weights1 + weights2) / 2)
    rmsd = np.sqrt(np.mean(weighted_squared_diff))

    return rmsd
1577
+
1578
@classmethod
def align_structures(
    cls,
    structure1: "Structure",
    structure2: "Structure",
    origin: NDArray = None,
    sampling_rate: float = None,
    weighted: bool = False,
) -> Tuple["Structure", float]:
    """
    Align ``structure2`` to ``structure1`` using the Kabsch Algorithm. Both
    structures need to have the same number of atoms.

    The RMSD before and after alignment is printed.

    Parameters
    ----------
    structure1, structure2 : :py:class:`Structure`
        Structure instances to align.
    origin : tuple of floats, optional
        Coordinate system origin. For computing RMSD on discretized grids.
        Defaults to the integer-truncated per-axis minimum over both
        structures.
    sampling_rate : tuple of floats, optional
        Sampling rate in units of :py:attr:`atom_coordinate`.
        For computing RMSD on discretized grids.
    weighted : bool, optional
        Whether atoms should be weighted by their atomic weight.

    Returns
    -------
    :py:class:`Structure`
        ``structure2`` aligned to ``structure1``.
    float
        Alignment RMSD

    Examples
    --------
    >>> from importlib_resources import files
    >>> from tme import Structure
    >>> from tme.matching_utils import get_rotation_matrices
    >>> fname = str(files("tests.data").joinpath("Structures/5khe.cif"))
    >>> structure = Structure.from_file(filename=fname)
    >>> transformed = structure.rigid_transform(
    ...     rotation_matrix = get_rotation_matrices(60)[2],
    ...     translation = (0, 1, -5)
    ... )
    >>> aligned, rmsd = Structure.align_structures(structure, transformed)
    Initial RMSD: 31.07189 - Final RMSD: 0.00000
    """
    if origin is None:
        origin = np.minimum(
            structure1.atom_coordinate.min(axis=0),
            structure2.atom_coordinate.min(axis=0),
        ).astype(int)

    rmsd_before = cls.compare_structures(
        structure1=structure1,
        structure2=structure2,
        origin=origin,
        sampling_rate=sampling_rate,
        weighted=weighted,
    )

    if sampling_rate is None:
        reference = structure1.atom_coordinate.copy()
        query = structure2.atom_coordinate.copy()
    else:
        # Align on grid positions instead of the raw coordinates.
        reference, *_ = structure1._coordinate_to_position(
            shape=None, sampling_rate=sampling_rate, origin=origin
        )
        query, *_ = structure2._coordinate_to_position(
            shape=None, sampling_rate=sampling_rate, origin=origin
        )

    reference_mean = reference.mean(axis=0)
    query_mean = query.mean(axis=0)
    reference_centered = reference - reference_mean
    query_centered = query - query_mean

    # Kabsch: the optimal rotation follows from the SVD of the covariance
    # of the centered point sets.
    covariance = np.dot(query_centered.T, reference_centered)
    U, _, Vh = np.linalg.svd(covariance)

    rotation = np.dot(Vh.T, U.T).T
    if np.linalg.det(rotation) < 0:
        # Flip the last singular vector to turn a reflection into a rotation.
        Vh[2, :] *= -1
        rotation = np.dot(Vh.T, U.T).T

    translation = reference_mean - np.dot(query_mean, rotation)

    reference_structure = structure1.copy()
    reference_structure.atom_coordinate = reference_centered + reference_mean
    aligned = structure2.copy()
    aligned.atom_coordinate = (
        np.dot(query_centered + query_mean, rotation) + translation
    )

    rmsd_after = cls.compare_structures(
        structure1=reference_structure,
        structure2=aligned,
        origin=origin,
        sampling_rate=None,
        weighted=weighted,
    )

    print(f"Initial RMSD: {rmsd_before:.5f} - Final RMSD: {rmsd_after:.5f}")

    return aligned, rmsd_after
1680
+
1681
def align_to_axis(
    self, coordinates: NDArray = None, axis: int = 2, flip: bool = False, **kwargs
):
    """
    Call the module-level ``align_to_axis`` helper on a set of coordinates.

    Parameters
    ----------
    coordinates : NDArray, optional
        Coordinates passed to the helper. Defaults to the atom coordinates
        of the instance.
    axis : int, optional
        Forwarded to the helper. Default is 2.
    flip : bool, optional
        Forwarded to the helper. Default is False.
    **kwargs
        Additional keyword arguments forwarded to the helper.

    Returns
    -------
    The return value of the helper, unchanged.
    """
    target = self.atom_coordinate if coordinates is None else coordinates
    # Inside the class body this name resolves to the module-level function,
    # not to this method.
    return align_to_axis(target, axis=axis, flip=flip, **kwargs)
1688
+
1689
+
1690
+ @dataclass(frozen=True, repr=True)
1691
+ class _Elements:
1692
+ """Lookup table for chemical elements."""
1693
+
1694
+ Atom = namedtuple(
1695
+ "Atom",
1696
+ [
1697
+ "atomic_number",
1698
+ "atomic_radius",
1699
+ "lattice_constant",
1700
+ "lattice_structure",
1701
+ "vdwr",
1702
+ "covalent_radius_bragg",
1703
+ "atomic_weight",
1704
+ ],
1705
+ )
1706
+ _default = Atom(0, 0, 0, "Atom does not exist in ressource.", 0, 0, 0)
1707
+ _elements = {
1708
+ "H": Atom(1, 25, 3.75, "HEX", 110, np.nan, 1.008),
1709
+ "HE": Atom(2, 120, 3.57, "HEX", 140, np.nan, 4.002602),
1710
+ "LI": Atom(3, 145, 3.49, "BCC", 182, 150, 6.94),
1711
+ "BE": Atom(4, 105, 2.29, "HEX", 153, 115, 9.0121831),
1712
+ "B": Atom(5, 85, 8.73, "TET", 192, np.nan, 10.81),
1713
+ "C": Atom(6, 70, 3.57, "DIA", 170, 77, 12.011),
1714
+ "N": Atom(7, 65, 4.039, "HEX", 155, 65, 14.007),
1715
+ "O": Atom(8, 60, 6.83, "CUB", 152, 65, 15.999),
1716
+ "F": Atom(9, 50, np.nan, "MCL", 147, 67, 18.998403163),
1717
+ "NE": Atom(10, 160, 4.43, "FCC", 154, np.nan, 20.1797),
1718
+ "NA": Atom(11, 180, 4.23, "BCC", 227, 177, 22.98976928),
1719
+ "MG": Atom(12, 150, 3.21, "HEX", 173, 142, 24.305),
1720
+ "AL": Atom(13, 125, 4.05, "FCC", 184, 135, 26.9815385),
1721
+ "SI": Atom(14, 110, 5.43, "DIA", 210, 117, 28.085),
1722
+ "P": Atom(15, 100, 7.17, "CUB", 180, np.nan, 30.973761998),
1723
+ "S": Atom(16, 100, 10.47, "ORC", 180, 102, 32.06),
1724
+ "CL": Atom(17, 100, 6.24, "ORC", 175, 105, 35.45),
1725
+ "AR": Atom(18, 71, 5.26, "FCC", 188, np.nan, 39.948),
1726
+ "K": Atom(19, 220, 5.23, "BCC", 275, 207, 39.0983),
1727
+ "CA": Atom(20, 180, 5.58, "FCC", 231, 170, 40.078),
1728
+ "SC": Atom(21, 160, 3.31, "HEX", 215, np.nan, 44.955908),
1729
+ "TI": Atom(22, 140, 2.95, "HEX", 211, 140, 47.867),
1730
+ "V": Atom(23, 135, 3.02, "BCC", 207, np.nan, 50.9415),
1731
+ "CR": Atom(24, 140, 2.88, "BCC", 206, 140, 51.9961),
1732
+ "MN": Atom(25, 140, 8.89, "CUB", 205, 147, 54.938044),
1733
+ "FE": Atom(26, 140, 2.87, "BCC", 204, 140, 55.845),
1734
+ "CO": Atom(27, 135, 2.51, "HEX", 200, 137, 58.933194),
1735
+ "NI": Atom(28, 135, 3.52, "FCC", 197, 135, 58.6934),
1736
+ "CU": Atom(29, 135, 3.61, "FCC", 196, 137, 63.546),
1737
+ "ZN": Atom(30, 135, 2.66, "HEX", 201, 132, 65.38),
1738
+ "GA": Atom(31, 130, 4.51, "ORC", 187, np.nan, 69.723),
1739
+ "GE": Atom(32, 125, 5.66, "DIA", 211, np.nan, 72.63),
1740
+ "AS": Atom(33, 115, 4.13, "RHL", 185, 126, 74.921595),
1741
+ "SE": Atom(34, 115, 4.36, "HEX", 190, 117, 78.971),
1742
+ "BR": Atom(35, 115, 6.67, "ORC", 185, 119, 79.904),
1743
+ "KR": Atom(36, np.nan, 5.72, "FCC", 202, np.nan, 83.798),
1744
+ "RB": Atom(37, 235, 5.59, "BCC", 303, 225, 85.4678),
1745
+ "SR": Atom(38, 200, 6.08, "FCC", 249, 195, 87.62),
1746
+ "Y": Atom(39, 180, 3.65, "HEX", 232, np.nan, 88.90584),
1747
+ "ZR": Atom(40, 155, 3.23, "HEX", 223, np.nan, 91.224),
1748
+ "NB": Atom(41, 145, 3.3, "BCC", 218, np.nan, 92.90637),
1749
+ "MO": Atom(42, 145, 3.15, "BCC", 217, np.nan, 95.95),
1750
+ "TC": Atom(43, 135, 2.74, "HEX", 216, np.nan, 97.90721),
1751
+ "RU": Atom(44, 130, 2.7, "HEX", 213, np.nan, 101.07),
1752
+ "RH": Atom(45, 135, 3.8, "FCC", 210, np.nan, 102.9055),
1753
+ "PD": Atom(46, 140, 3.89, "FCC", 210, np.nan, 106.42),
1754
+ "AG": Atom(47, 160, 4.09, "FCC", 211, 177, 107.8682),
1755
+ "CD": Atom(48, 155, 2.98, "HEX", 218, 160, 112.414),
1756
+ "IN": Atom(49, 155, 4.59, "TET", 193, np.nan, 114.818),
1757
+ "SN": Atom(50, 145, 5.82, "TET", 217, 140, 118.71),
1758
+ "SB": Atom(51, 145, 4.51, "RHL", 206, 140, 121.76),
1759
+ "TE": Atom(52, 140, 4.45, "HEX", 206, 133, 127.6),
1760
+ "I": Atom(53, 140, 7.72, "ORC", 198, 140, 126.90447),
1761
+ "XE": Atom(54, np.nan, 6.2, "FCC", 216, np.nan, 131.293),
1762
+ "CS": Atom(55, 260, 6.05, "BCC", 343, 237, 132.90545196),
1763
+ "BA": Atom(56, 215, 5.02, "BCC", 268, 210, 137.327),
1764
+ "LA": Atom(57, 195, 3.75, "HEX", 243, np.nan, 138.90547),
1765
+ "CE": Atom(58, 185, 5.16, "FCC", 242, np.nan, 140.116),
1766
+ "PR": Atom(59, 185, 3.67, "HEX", 240, np.nan, 140.90766),
1767
+ "ND": Atom(60, 185, 3.66, "HEX", 239, np.nan, 144.242),
1768
+ "PM": Atom(61, 185, np.nan, "", 238, np.nan, 144.91276),
1769
+ "SM": Atom(62, 185, 9, "RHL", 236, np.nan, 150.36),
1770
+ "EU": Atom(63, 185, 4.61, "BCC", 235, np.nan, 151.964),
1771
+ "GD": Atom(64, 180, 3.64, "HEX", 234, np.nan, 157.25),
1772
+ "TB": Atom(65, 175, 3.6, "HEX", 233, np.nan, 158.92535),
1773
+ "DY": Atom(66, 175, 3.59, "HEX", 231, np.nan, 162.5),
1774
+ "HO": Atom(67, 175, 3.58, "HEX", 230, np.nan, 164.93033),
1775
+ "ER": Atom(68, 175, 3.56, "HEX", 229, np.nan, 167.259),
1776
+ "TM": Atom(69, 175, 3.54, "HEX", 227, np.nan, 168.93422),
1777
+ "YB": Atom(70, 175, 5.49, "FCC", 226, np.nan, 173.045),
1778
+ "LU": Atom(71, 175, 3.51, "HEX", 224, np.nan, 174.9668),
1779
+ "HF": Atom(72, 155, 3.2, "HEX", 223, np.nan, 178.49),
1780
+ "TA": Atom(73, 145, 3.31, "BCC", 222, np.nan, 180.94788),
1781
+ "W": Atom(74, 135, 3.16, "BCC", 218, np.nan, 183.84),
1782
+ "RE": Atom(75, 135, 2.76, "HEX", 216, np.nan, 186.207),
1783
+ "OS": Atom(76, 130, 2.74, "HEX", 216, np.nan, 190.23),
1784
+ "IR": Atom(77, 135, 3.84, "FCC", 213, np.nan, 192.217),
1785
+ "PT": Atom(78, 135, 3.92, "FCC", 213, np.nan, 195.084),
1786
+ "AU": Atom(79, 135, 4.08, "FCC", 214, np.nan, 196.966569),
1787
+ "HG": Atom(80, 150, 2.99, "RHL", 223, np.nan, 200.592),
1788
+ "TL": Atom(81, 190, 3.46, "HEX", 196, 190, 204.38),
1789
+ "PB": Atom(82, 180, 4.95, "FCC", 202, np.nan, 207.2),
1790
+ "BI": Atom(83, 160, 4.75, "RHL", 207, 148, 208.9804),
1791
+ "PO": Atom(84, 190, 3.35, "SC", 197, np.nan, 209),
1792
+ "AT": Atom(85, np.nan, np.nan, "", 202, np.nan, 210),
1793
+ "RN": Atom(86, np.nan, np.nan, "FCC", 220, np.nan, 222),
1794
+ "FR": Atom(87, np.nan, np.nan, "BCC", 348, np.nan, 223),
1795
+ "RA": Atom(88, 215, np.nan, "", 283, np.nan, 226),
1796
+ "AC": Atom(89, 195, 5.31, "FCC", 247, np.nan, 227),
1797
+ "TH": Atom(90, 180, 5.08, "FCC", 245, np.nan, 232.0377),
1798
+ "PA": Atom(91, 180, 3.92, "TET", 243, np.nan, 231.03588),
1799
+ "U": Atom(92, 175, 2.85, "ORC", 241, np.nan, 238.02891),
1800
+ "NP": Atom(93, 175, 4.72, "ORC", 239, np.nan, 237),
1801
+ "PU": Atom(94, 175, np.nan, "MCL", 243, np.nan, 244),
1802
+ "AM": Atom(95, 175, np.nan, "", 244, np.nan, 243),
1803
+ "CM": Atom(96, np.nan, np.nan, "", 245, np.nan, 247),
1804
+ "BK": Atom(97, np.nan, np.nan, "", 244, np.nan, 247),
1805
+ "CF": Atom(98, np.nan, np.nan, "", 245, np.nan, 251),
1806
+ "ES": Atom(99, np.nan, np.nan, "", 245, np.nan, 252),
1807
+ "FM": Atom(100, np.nan, np.nan, "", 245, np.nan, 257),
1808
+ "MD": Atom(101, np.nan, np.nan, "", 246, np.nan, 258),
1809
+ "NO": Atom(102, np.nan, np.nan, "", 246, np.nan, 259),
1810
+ "LR": Atom(103, np.nan, np.nan, "", 246, np.nan, 262),
1811
+ "RF": Atom(104, np.nan, np.nan, "", np.nan, np.nan, 267),
1812
+ "DB": Atom(105, np.nan, np.nan, "", np.nan, np.nan, 268),
1813
+ "SG": Atom(106, np.nan, np.nan, "", np.nan, np.nan, 271),
1814
+ "BH": Atom(107, np.nan, np.nan, "", np.nan, np.nan, 274),
1815
+ "HS": Atom(108, np.nan, np.nan, "", np.nan, np.nan, 269),
1816
+ "MT": Atom(109, np.nan, np.nan, "", np.nan, np.nan, 276),
1817
+ "DS": Atom(110, np.nan, np.nan, "", np.nan, np.nan, 281),
1818
+ "RG": Atom(111, np.nan, np.nan, "", np.nan, np.nan, 281),
1819
+ "CN": Atom(112, np.nan, np.nan, "", np.nan, np.nan, 285),
1820
+ "NH": Atom(113, np.nan, np.nan, "", np.nan, np.nan, 286),
1821
+ "FL": Atom(114, np.nan, np.nan, "", np.nan, np.nan, 289),
1822
+ "MC": Atom(115, np.nan, np.nan, "", np.nan, np.nan, 288),
1823
+ "LV": Atom(116, np.nan, np.nan, "", np.nan, np.nan, 293),
1824
+ "TS": Atom(117, np.nan, np.nan, "", np.nan, np.nan, 294),
1825
+ "OG": Atom(118, np.nan, np.nan, "", np.nan, np.nan, 294),
1826
+ }
1827
+
1828
+ def __getitem__(self, key: str):
1829
+ """
1830
+ Retrieve a value from the internal data using a given key.
1831
+
1832
+ Parameters
1833
+ ----------
1834
+ key : str
1835
+ Key to retrieve the corresponding value for.
1836
+
1837
+ Returns
1838
+ -------
1839
+ namedtuple
1840
+ The Atom tuple associated with the provided key.
1841
+ """
1842
+ return self._elements.get(key, self._default)
1843
+
1844
+
1845
+ def _format_string(string: str) -> str:
1846
+ """
1847
+ Formats a string by adding quotation marks if it contains white spaces.
1848
+
1849
+ Parameters
1850
+ ----------
1851
+ string : str
1852
+ Input string to be formatted.
1853
+
1854
+ Returns
1855
+ -------
1856
+ str
1857
+ Formatted string with added quotation marks if needed.
1858
+ """
1859
+ if " " in string:
1860
+ return f"'{string}'"
1861
+ # Occurs e.g. for C1' atoms. The trailing whitespace is necessary.
1862
+ if string.count("'") == 1:
1863
+ return f'"{string}"'
1864
+ return string