pytme 0.1.5__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. pytme-0.1.5.data/scripts/estimate_ram_usage.py +81 -0
  2. pytme-0.1.5.data/scripts/match_template.py +744 -0
  3. pytme-0.1.5.data/scripts/postprocess.py +279 -0
  4. pytme-0.1.5.data/scripts/preprocess.py +93 -0
  5. pytme-0.1.5.data/scripts/preprocessor_gui.py +729 -0
  6. pytme-0.1.5.dist-info/LICENSE +153 -0
  7. pytme-0.1.5.dist-info/METADATA +69 -0
  8. pytme-0.1.5.dist-info/RECORD +63 -0
  9. pytme-0.1.5.dist-info/WHEEL +5 -0
  10. pytme-0.1.5.dist-info/entry_points.txt +6 -0
  11. pytme-0.1.5.dist-info/top_level.txt +2 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +81 -0
  14. scripts/match_template.py +744 -0
  15. scripts/match_template_devel.py +788 -0
  16. scripts/postprocess.py +279 -0
  17. scripts/preprocess.py +93 -0
  18. scripts/preprocessor_gui.py +729 -0
  19. tme/__init__.py +6 -0
  20. tme/__version__.py +1 -0
  21. tme/analyzer.py +1144 -0
  22. tme/backends/__init__.py +134 -0
  23. tme/backends/cupy_backend.py +309 -0
  24. tme/backends/matching_backend.py +1154 -0
  25. tme/backends/npfftw_backend.py +763 -0
  26. tme/backends/pytorch_backend.py +526 -0
  27. tme/data/__init__.py +0 -0
  28. tme/data/c48n309.npy +0 -0
  29. tme/data/c48n527.npy +0 -0
  30. tme/data/c48n9.npy +0 -0
  31. tme/data/c48u1.npy +0 -0
  32. tme/data/c48u1153.npy +0 -0
  33. tme/data/c48u1201.npy +0 -0
  34. tme/data/c48u1641.npy +0 -0
  35. tme/data/c48u181.npy +0 -0
  36. tme/data/c48u2219.npy +0 -0
  37. tme/data/c48u27.npy +0 -0
  38. tme/data/c48u2947.npy +0 -0
  39. tme/data/c48u3733.npy +0 -0
  40. tme/data/c48u4749.npy +0 -0
  41. tme/data/c48u5879.npy +0 -0
  42. tme/data/c48u7111.npy +0 -0
  43. tme/data/c48u815.npy +0 -0
  44. tme/data/c48u83.npy +0 -0
  45. tme/data/c48u8649.npy +0 -0
  46. tme/data/c600v.npy +0 -0
  47. tme/data/c600vc.npy +0 -0
  48. tme/data/metadata.yaml +80 -0
  49. tme/data/quat_to_numpy.py +42 -0
  50. tme/data/scattering_factors.pickle +0 -0
  51. tme/density.py +2314 -0
  52. tme/extensions.cpython-311-darwin.so +0 -0
  53. tme/helpers.py +881 -0
  54. tme/matching_data.py +377 -0
  55. tme/matching_exhaustive.py +1553 -0
  56. tme/matching_memory.py +382 -0
  57. tme/matching_optimization.py +1123 -0
  58. tme/matching_utils.py +1180 -0
  59. tme/parser.py +429 -0
  60. tme/preprocessor.py +1291 -0
  61. tme/scoring.py +866 -0
  62. tme/structure.py +1428 -0
  63. tme/types.py +10 -0
tme/structure.py ADDED
@@ -0,0 +1,1428 @@
1
+ """ Implements class Structure to represent atomic structures.
2
+
3
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+ import warnings
8
+ from copy import deepcopy
9
+ from collections import namedtuple
10
+ from typing import List, Dict, Tuple
11
+ from itertools import groupby
12
+ from dataclasses import dataclass
13
+ from os.path import splitext, basename
14
+
15
+ import numpy as np
16
+
17
+ from .parser import PDBParser, MMCIFParser
18
+ from .matching_utils import (
19
+ rigid_transform,
20
+ _format_mmcif_colunns,
21
+ minimum_enclosing_box,
22
+ )
23
+ from .helpers import atom_profile
24
+ from .types import NDArray
25
+
26
+
27
+ @dataclass(repr=False)
28
+ class Structure:
29
+ """Represents atomic structures in accordance with the Protein Data Bank (PDB)
30
+ format specification.
31
+
32
+ Attributes
33
+ ----------
34
+ record_type : NDArray
35
+ Type of the record, e.g., ATOM, HETATM. Array shape = (n,)
36
+ atom_serial_number : NDArray
37
+ Serial number assigned to each atom. Array shape = (n,)
38
+ atom_name : NDArray
39
+ Standardized names for each atom. Array shape = (n,)
40
+ atom_coordinate : NDArray
41
+ The 3D Cartesian coordinates of each atom in x, y, z. Array shape = (n, 3)
42
+ alternate_location_indicator : NDArray
43
+ Indicator for alternate locations of an atom if it exists in multiple places.
44
+ Array shape = (n,)
45
+ residue_name : NDArray
46
+ Standard residue names where each atom belongs. Array shape = (n,)
47
+ chain_identifier : NDArray
48
+ Identifier for the chain where each atom is located. Array shape = (n,)
49
+ residue_sequence_number : NDArray
50
+ Sequence number of the residue in the protein chain for each atom.
51
+ Array shape = (n,)
52
+ code_for_residue_insertion : NDArray
53
+ Code to denote any residue insertion. Array shape = (n,)
54
+ occupancy : NDArray
55
+ Occupancy factor of each atom, indicating the fraction of time the atom
56
+ is located at its position. Array shape = (n,)
57
+ temperature_factor : NDArray
58
+ Measure of the atomic displacement or B-factor for each atom. Array shape = (n,)
59
+ segment_identifier : NDArray
60
+ Identifier for the segment where each atom belongs. Array shape = (n,)
61
+ element_symbol : NDArray
62
+ Atomic element symbol for each atom. Array shape = (n,)
63
+ charge : NDArray
64
+ Charge on the atom. Array shape = (n,)
65
+ details : dict
66
+ Dictionary with additional or auxiliary details about the structure.
67
+
68
+ References
69
+ ----------
70
+ .. [1] https://www.cgl.ucsf.edu/chimera/docs/UsersGuide/tutorials/pdbintro.html
71
+ """
72
+
73
+ #: Return a numpy array with record types, e.g. ATOM, HETATM.
74
+ record_type: NDArray
75
+
76
+ #: Return a numpy array with serial number of each atom.
77
+ atom_serial_number: NDArray
78
+
79
+ #: Return a numpy array with name of each atom.
80
+ atom_name: NDArray
81
+
82
+ #: Return a numpy array with coordinates of each atom in x, y, z.
83
+ atom_coordinate: NDArray
84
+
85
+ #: Return a numpy array with alternate location indicators of each atom.
86
+ alternate_location_indicator: NDArray
87
+
88
+ #: Return a numpy array with originating residue names of each atom.
89
+ residue_name: NDArray
90
+
91
+ #: Return a numpy array with originating structure chain of each atom.
92
+ chain_identifier: NDArray
93
+
94
+ #: Return a numpy array with originating residue id of each atom.
95
+ residue_sequence_number: NDArray
96
+
97
+ #: Return a numpy array with insertion information of each atom.
98
+ code_for_residue_insertion: NDArray
99
+
100
+ #: Return a numpy array with occupancy factors of each atom.
101
+ occupancy: NDArray
102
+
103
+ #: Return a numpy array with B-factors for each atom.
104
+ temperature_factor: NDArray
105
+
106
+ #: Return a numpy array with segment identifier for each atom.
107
+ segment_identifier: NDArray
108
+
109
+ #: Return a numpy array with element symbols of each atom.
110
+ element_symbol: NDArray
111
+
112
+ #: Return a numpy array with charges of each atom.
113
+ charge: NDArray
114
+
115
+ #: Returns a dictionary with class instance metadata.
116
+ details: dict
117
+
118
def __post_init__(self, *args, **kwargs):
    """
    Initialize the structure and populate header details.

    Creates the element lookup table, fills ``self.details`` via
    ``_populate_details`` and checks that every array attribute has one
    entry per atom.

    Raises
    ------
    ValueError
        If the first dimension of any NDArray attribute does not match the
        number of atoms given by ``atom_coordinate``.
    """
    self._elements = Elements()
    self.details = self._populate_details(self.details)

    n_atoms = self.atom_coordinate.shape[0]
    for attribute, value in self.__dict__.items():
        # isinstance also validates ndarray subclasses, which an exact type
        # comparison silently skipped.
        if not isinstance(value, np.ndarray):
            continue
        if value.shape[0] != n_atoms:
            raise ValueError(
                f"Expected shape of {attribute}: {n_atoms}, got {value.shape[0]}."
            )
140
+
141
+ def __getitem__(self, indices: List[int]) -> "Structure":
142
+ """
143
+ Get a Structure instance for specified indices.
144
+
145
+ Parameters
146
+ ----------
147
+ indices : Union[int, bool, NDArray]
148
+ The indices to get.
149
+
150
+ Returns
151
+ -------
152
+ Structure
153
+ The Structure instance for the given indices.
154
+ """
155
+ if type(indices) in (int, bool):
156
+ indices = (indices,)
157
+
158
+ indices = np.asarray(indices)
159
+ attributes = (
160
+ "record_type",
161
+ "atom_serial_number",
162
+ "atom_name",
163
+ "atom_coordinate",
164
+ "alternate_location_indicator",
165
+ "residue_name",
166
+ "chain_identifier",
167
+ "residue_sequence_number",
168
+ "code_for_residue_insertion",
169
+ "occupancy",
170
+ "temperature_factor",
171
+ "segment_identifier",
172
+ "element_symbol",
173
+ "charge",
174
+ )
175
+ kwargs = {attr: getattr(self, attr)[indices] for attr in attributes}
176
+ ret = self.__class__(**kwargs, details={})
177
+ return ret
178
+
179
+ def __repr__(self):
180
+ """
181
+ Return a string representation of the Structure.
182
+
183
+ Returns
184
+ -------
185
+ str
186
+ The string representation.
187
+ """
188
+ unique_chains = "-".join(
189
+ [
190
+ ",".join([str(x) for x in entity])
191
+ for entity in self.details["unique_chains"]
192
+ ]
193
+ )
194
+ min_atom = np.min(self.atom_serial_number)
195
+ max_atom = np.max(self.atom_serial_number)
196
+ n_atom = self.atom_serial_number.size
197
+
198
+ min_residue = np.min(self.residue_sequence_number)
199
+ max_residue = np.max(self.residue_sequence_number)
200
+ n_residue = self.residue_sequence_number.size
201
+
202
+ repr_str = (
203
+ f"Structure object at {id(self)}\n"
204
+ f"Unique Chains: {unique_chains}, "
205
+ f"Atom Range: {min_atom}-{max_atom} [N = {n_atom}], "
206
+ f"Residue Range: {min_residue}-{max_residue} [N = {n_residue}]"
207
+ )
208
+ return repr_str
209
+
210
def get_chains(self) -> List[str]:
    """
    List the chain identifiers present in the structure.

    The identifiers are the keys of ``self.details["chain_weight"]``.

    Returns
    -------
    list
        The list of available chains.
    """
    chain_weights = self.details["chain_weight"]
    return [*chain_weights.keys()]
220
+
221
def copy(self) -> "Structure":
    """
    Create an independent deep copy of this Structure instance.

    Returns
    -------
    Structure
        The copied Structure instance.
    """
    duplicate = deepcopy(self)
    return duplicate
231
+
232
+ def _populate_details(self, details: Dict = {}) -> Dict:
233
+ """
234
+ Populate the details dictionary with the data from the Structure instance.
235
+
236
+ Parameters
237
+ ----------
238
+ details : dict, optional
239
+ The initial details dictionary, by default {}.
240
+
241
+ Returns
242
+ -------
243
+ dict
244
+ The populated details dictionary.
245
+ """
246
+ details["weight"] = np.sum(
247
+ [self._elements[atype].atomic_weight for atype in self.element_symbol]
248
+ )
249
+
250
+ label, idx, chain = np.unique(
251
+ self.chain_identifier, return_inverse=True, return_index=True
252
+ )
253
+ chain_weight = np.bincount(
254
+ chain,
255
+ [self._elements[atype].atomic_weight for atype in self.element_symbol],
256
+ )
257
+ labels = self.chain_identifier[idx]
258
+ details["chain_weight"] = {key: val for key, val in zip(labels, chain_weight)}
259
+
260
+ # Group non-unique chains in separate lists in details["unique_chains"]
261
+ details["unique_chains"], temp = [], {}
262
+ for chain_label in label:
263
+ index = len(details["unique_chains"])
264
+ chain_sequence = "".join(
265
+ [
266
+ str(y)
267
+ for y in self.element_symbol[
268
+ np.where(self.chain_identifier == chain_label)
269
+ ]
270
+ ]
271
+ )
272
+ if chain_sequence not in temp:
273
+ temp[chain_sequence] = index
274
+ details["unique_chains"].append([chain_label])
275
+ continue
276
+ idx = temp.get(chain_sequence)
277
+ details["unique_chains"][idx].append(chain_label)
278
+
279
+ filtered_data = [
280
+ (label, integer)
281
+ for label, integer in zip(
282
+ self.chain_identifier, self.residue_sequence_number
283
+ )
284
+ ]
285
+ filtered_data = sorted(filtered_data, key=lambda x: x[0])
286
+ details["chain_range"] = {}
287
+ for label, values in groupby(filtered_data, key=lambda x: x[0]):
288
+ values = [int(x[1]) for x in values]
289
+ details["chain_range"][label] = (min(values), max(values))
290
+
291
+ return details
292
+
293
@classmethod
def from_file(
    cls,
    filename: str,
    keep_non_atom_records: bool = False,
    filter_by_elements: set = None,
    filter_by_residues: set = None,
) -> "Structure":
    """
    Reads in an mmcif or pdb file and converts it into class instance.

    Parameters
    ----------
    filename : str
        Path to the mmcif or pdb file.
    keep_non_atom_records : bool, optional
        Whether to keep records that are not labelled ATOM.
    filter_by_elements: set, optional
        Which elements to keep. Default corresponds to all elements.
    filter_by_residues: set, optional
        Which residues to keep. Default corresponds to all residues.

    Raises
    ------
    NotImplementedError
        If the extension is not '.pdb' or '.cif' (case-insensitive).

    Returns
    -------
    Structure
        Read in structure file.
    """
    _, file_extension = splitext(basename(filename.upper()))
    if file_extension == ".PDB":
        func = cls._load_pdb
    elif file_extension == ".CIF":
        func = cls._load_mmcif
    else:
        raise NotImplementedError(
            "Could not determine structure filetype from extension."
            " Supported filetypes are mmcif (.cif) and pdb (.pdb)."
        )
    data = func(filename)

    # Boolean mask over atoms; every active filter narrows it further.
    keep = np.ones(data["element_symbol"].size, dtype=bool)
    if filter_by_elements:
        # np.isin replaces np.in1d, which is deprecated in recent NumPy.
        keep &= np.isin(data["element_symbol"], list(filter_by_elements))
    if filter_by_residues:
        keep &= np.isin(data["residue_name"], list(filter_by_residues))
    if not keep_non_atom_records:
        keep = np.logical_and(keep, data["record_type"] == "ATOM")

    for key in data:
        if key == "details":
            continue
        if isinstance(data[key], np.ndarray):
            data[key] = data[key][keep]
        else:
            data[key] = [x for x, flag in zip(data[key], keep) if flag]

    data["details"]["filepath"] = filename

    return cls(**data)
361
+
362
@staticmethod
def _load_mmcif(filename: str) -> Dict:
    """
    Parses a macromolecular Crystallographic Information File (mmCIF)
    and returns the data in a dictionary format.

    Parameters
    ----------
    filename : str
        The filename of the mmCIF to load.

    Returns
    -------
    dict
        A dictionary of numpy arrays. Keys are the names of the PDB
        coordinate section. In addition, some details about the parsed
        structure are included. Integer fields given as '.' or '?' are read
        as 0. If converting a column fails, the whole column is set to 0
        for numeric fields and to '.' for string fields.
    """
    result = MMCIFParser(filename)
    atom_site = result["atom_site"]

    atom_site_mapping = {
        "record_type": ("group_PDB", str),
        "atom_serial_number": ("id", int),
        "atom_name": ("label_atom_id", str),
        "alternate_location_indicator": ("label_alt_id", str),
        "residue_name": ("label_comp_id", str),
        # "chain_identifier": ("auth_asym_id", str),
        "chain_identifier": ("label_asym_id", str),
        "residue_sequence_number": ("label_seq_id", int),
        "code_for_residue_insertion": ("pdbx_PDB_ins_code", str),
        "occupancy": ("occupancy", float),
        "temperature_factor": ("B_iso_or_equiv", float),
        "segment_identifier": ("pdbx_PDB_model_num", str),
        "element_symbol": ("type_symbol", str),
        "charge": ("pdbx_formal_charge", str),
    }

    out = {}
    for out_key, (atom_site_key, dtype) in atom_site_mapping.items():
        out_data = [x.strip() for x in atom_site.get(atom_site_key, ["."])]
        try:
            # The int conversion lives inside the try so that malformed
            # values trigger the documented fallback instead of raising.
            if dtype == int:
                out_data = [0 if x in (".", "?") else int(x) for x in out_data]
            out[out_key] = np.asarray(out_data).astype(dtype)
        except ValueError:
            default = ["."] if dtype == str else 0
            print(f"Converting {out_key} to {dtype} failed, set to {default}.")
            out[out_key] = np.repeat(default, len(out_data)).astype(dtype)

    # Columns missing from the file hold a single placeholder; expand them
    # to the length of the longest column.
    number_entries = len(max(out.values(), key=len))
    for key, value in out.items():
        if value.size != 1:
            continue
        out[key] = np.repeat(value, number_entries)

    out["details"] = {}
    out["atom_coordinate"] = np.transpose(
        np.array(
            [atom_site["Cartn_x"], atom_site["Cartn_y"], atom_site["Cartn_z"]],
            dtype=np.float32,
        )
    )

    detail_mapping = {
        "resolution": ("em_3d_reconstruction", "resolution", np.nan),
        "resolution_method": ("em_3d_reconstruction", "resolution_method", np.nan),
        "method": ("exptl", "method", np.nan),
        "electron_source": ("em_imaging", "electron_source", np.nan),
        "illumination_mode": ("em_imaging", "illumination_mode", np.nan),
        "microscope_model": ("em_imaging", "microscope_model", np.nan),
    }
    for out_key, (base_key, inner_key, default) in detail_mapping.items():
        if base_key not in result:
            continue
        out["details"][out_key] = result[base_key].get(inner_key, default)

    return out
446
+
447
@staticmethod
def _load_pdb(filename: str) -> Dict:
    """
    Parses a Protein Data Bank (PDB) file and returns the data
    in a dictionary format.

    Parameters
    ----------
    filename : str
        The filename of the PDB file to load.

    Returns
    -------
    dict
        A dictionary of numpy arrays. Keys are the names of the PDB
        coordinate section. In addition, some details about the parsed
        structure are included. If converting a column fails, the whole
        column is set to 0 for numeric fields and to '.' for string fields.
    """
    result = PDBParser(filename)

    atom_site_mapping = {
        "record_type": ("record_type", str),
        "atom_serial_number": ("atom_serial_number", int),
        "atom_name": ("atom_name", str),
        "alternate_location_indicator": ("alternate_location_indicator", str),
        "residue_name": ("residue_name", str),
        "chain_identifier": ("chain_identifier", str),
        "residue_sequence_number": ("residue_sequence_number", int),
        "code_for_residue_insertion": ("code_for_residue_insertion", str),
        "occupancy": ("occupancy", float),
        "temperature_factor": ("temperature_factor", float),
        "segment_identifier": ("segment_identifier", str),
        "element_symbol": ("element_symbol", str),
        "charge": ("charge", str),
    }

    out = {"details": result["details"]}
    for out_key, (inner_key, dtype) in atom_site_mapping.items():
        out_data = [x.strip() for x in result[inner_key]]
        try:
            # The int conversion lives inside the try so that malformed
            # values trigger the documented fallback instead of raising.
            if dtype == int:
                out_data = [0 if x == "." else int(x) for x in out_data]
            out[out_key] = np.asarray(out_data).astype(dtype)
        except ValueError:
            default = "." if dtype == str else 0
            print(
                f"Converting {out_key} to {dtype} failed. Setting {out_key} to {default}."
            )
            out[out_key] = np.repeat(default, len(out_data)).astype(dtype)

    out["atom_coordinate"] = np.array(result["atom_coordinate"], dtype=np.float32)

    return out
501
+
502
def to_file(self, filename: str) -> None:
    """
    Writes the Structure instance data to a Protein Data Bank (PDB) or
    macromolecular Crystallographic Information File (mmCIF) file depending
    on whether filename ends with '.pdb' or '.cif' (case-insensitive).

    Parameters
    ----------
    filename : str
        The filename of the file to write.

    Raises
    ------
    NotImplementedError
        If the extension is not '.pdb' or '.cif'.

    Warns
    -----
    UserWarning
        If a chain identifier is longer than one character, because the
        writers only keep the first character, and if more than 100,000
        atoms are written to a PDB file.
    """
    # A plain loop also works for empty structures, unlike np.vectorize(len).
    if any(len(str(chain)) > 1 for chain in self.chain_identifier):
        warnings.warn("Chain identifiers longer than one will be shortened.")

    _, file_extension = splitext(basename(filename.upper()))
    if file_extension == ".PDB":
        writer, is_pdb = self._write_pdb, True
    elif file_extension == ".CIF":
        writer, is_pdb = self._write_mmcif, False
    else:
        raise NotImplementedError(
            "Could not determine structure filetype."
            " Supported filetypes are mmcif (.cif) and pdb (.pdb)."
        )

    if is_pdb and self.atom_coordinate.shape[0] > 10**5:
        warnings.warn(
            "The structure contains more than 100,000 atoms. Consider using mmcif."
        )

    content = writer()
    with open(filename, mode="w", encoding="utf-8") as ofile:
        if isinstance(content, str):
            ofile.write(content)
        else:
            ofile.writelines(content)
540
+
541
def _write_pdb(self) -> str:
    """
    Returns a PDB string representation of the structure instance.

    Every atom is written into an 80 character wide line buffer using
    fixed column slices. Only the first character of each chain identifier
    is written.

    Returns
    -------
    str
        Newline-joined PDB coordinate lines followed by a final ``END``
        line. No trailing newline is appended.
    """
    data_out = []
    for index in range(self.atom_coordinate.shape[0]):
        x, y, z = self.atom_coordinate[index, :]
        # NOTE(review): values wider than their column slice are not
        # truncated; slice assignment grows the buffer and shifts the
        # remaining characters of the line.
        line = list(" " * 80)
        line[0:6] = f"{self.record_type[index]:<6}"
        line[6:11] = f"{self.atom_serial_number[index]:>5}"
        line[12:16] = f"{self.atom_name[index]:<4}"
        line[16] = f"{self.alternate_location_indicator[index]:<1}"
        line[17:20] = f"{self.residue_name[index]:<3}"
        line[21] = f"{self.chain_identifier[index][0]:<1}"
        line[22:26] = f"{self.residue_sequence_number[index]:>4}"
        line[26] = f"{self.code_for_residue_insertion[index]:<1}"
        line[30:38] = f"{x:>8.3f}"
        line[38:46] = f"{y:>8.3f}"
        line[46:54] = f"{z:>8.3f}"
        line[54:60] = f"{self.occupancy[index]:>6.2f}"
        line[60:66] = f"{self.temperature_factor[index]:>6.2f}"
        line[72:76] = f"{self.segment_identifier[index]:>4}"
        line[76:78] = f"{self.element_symbol[index]:<2}"
        line[78:80] = f"{self.charge[index]:>2}"
        data_out.append("".join(line))
    data_out.append("END")
    data_out = "\n".join(data_out)
    return data_out
574
+
575
def _write_mmcif(self) -> List[str]:
    """
    Returns a MMCIF string representation of the structure instance.

    The atom_site table is built from the instance attributes. If the file
    referenced by ``self.details["filepath"]`` can be parsed as mmCIF, its
    categories are reused, its atom_site rows are selected by atom serial
    number and only the coordinates are replaced. Any failure in that step
    falls back to the freshly built atom_site table.

    Returns
    -------
    list
        MMCIF file content. The value is a single string.
    """
    model_num, entity_id = 1, 1
    atom_indices = range(self.atom_coordinate.shape[0])

    def column(values, convert=lambda value: value):
        # One entry per atom, looked up by atom index.
        return [convert(values[i]) for i in atom_indices]

    xyz = [tuple(self.atom_coordinate[i, :]) for i in atom_indices]
    data = {
        "group_PDB": column(self.record_type),
        "id": column(self.atom_serial_number, str),
        "type_symbol": column(self.element_symbol),
        "label_atom_id": column(self.atom_name),
        "label_alt_id": column(self.alternate_location_indicator),
        "label_comp_id": column(self.residue_name),
        "label_asym_id": column(self.chain_identifier, lambda chain: chain[0]),
        "label_entity_id": [str(entity_id) for _ in atom_indices],
        "label_seq_id": column(self.residue_sequence_number, str),
        "pdbx_PDB_ins_code": column(self.code_for_residue_insertion),
        "Cartn_x": [f"{x:.3f}" for x, _, _ in xyz],
        "Cartn_y": [f"{y:.3f}" for _, y, _ in xyz],
        "Cartn_z": [f"{z:.3f}" for _, _, z in xyz],
        "occupancy": column(self.occupancy, lambda value: f"{value:.2f}"),
        "B_iso_or_equiv": column(
            self.temperature_factor, lambda value: f"{value:.2f}"
        ),
        "pdbx_formal_charge": column(self.charge),
        "auth_seq_id": column(self.residue_sequence_number, str),
        "auth_comp_id": column(self.residue_name),
        "auth_asym_id": column(self.chain_identifier, lambda chain: chain[0]),
        "auth_atom_id": column(self.atom_name),
        "pdbx_PDB_model_num": [str(model_num) for _ in atom_indices],
    }

    output_data = {"atom_site": data}
    original_file = self.details.get("filepath", "")
    try:
        new_data = dict(MMCIFParser(original_file).items())
        index = self.atom_serial_number - 1
        new_data["atom_site"] = {
            k: [v[i] for i in index] for k, v in new_data["atom_site"].items()
        }
        for axis_key in ("Cartn_x", "Cartn_y", "Cartn_z"):
            new_data["atom_site"][axis_key] = data[axis_key]
        output_data = new_data
    except Exception:
        # Best effort only: keep the table built from the instance.
        pass

    chunks = []
    for category, subdict in output_data.items():
        chunks.append("#\n")
        first_value = subdict[list(subdict.keys())[0]]
        if isinstance(first_value, list):
            chunks.append("loop_\n")
            chunks.extend(f"_{category}.{k}\n" for k in subdict)
            padded_subdict = _format_mmcif_colunns(subdict)
            rows = [
                "".join([str(x) for x in content])
                for content in zip(*padded_subdict.values())
            ]
            chunks.append("\n".join(rows) + "\n")
        else:
            chunks.extend(f"_{category}.{k}\t{subdict[k]}\n" for k in subdict)

    return "".join(chunks)
667
+
668
def subset_by_chain(self, chain: str = None) -> "Structure":
    """
    Return a subset of the structure that contains only atoms belonging to
    a specific chain. If no chain is specified, all chains are returned.

    Parameters
    ----------
    chain : str, optional
        The chain identifier. If multiple chains should be selected they need
        to be a comma separated string, e.g. 'A,B,CE'. Whitespace around
        the individual identifiers is ignored. If chain None, all chains
        are returned. Default is None.

    Returns
    -------
    Structure
        A subset of the original structure containing only the specified chain.
    """
    if chain is None:
        chain = np.unique(self.chain_identifier)
    else:
        chain = [label.strip() for label in chain.split(",")]
    # np.isin replaces np.in1d, which is deprecated in recent NumPy.
    keep = np.isin(self.chain_identifier, chain)
    return self[keep]
688
+
689
def subset_by_range(
    self,
    start: int,
    stop: int,
    chain: str = None,
) -> "Structure":
    """
    Select the atoms whose residue sequence number lies in [start, stop].

    Parameters
    ----------
    start : int
        The starting residue sequence number (inclusive).

    stop : int
        The ending residue sequence number (inclusive).

    chain : str, optional
        The chain identifier. If multiple chains should be selected they need
        to be a comma separated string, e.g. 'A,B,CE'. If chain None,
        all chains are returned. Default is None.

    Returns
    -------
    Structure
        A subset of the original structure within the specified residue range.
    """
    selected = self.subset_by_chain(chain=chain)
    residue_ids = selected.residue_sequence_number
    in_range = np.logical_and(residue_ids >= start, residue_ids <= stop)
    return selected[in_range]
721
+
722
def center_of_mass(self) -> NDArray:
    """
    Compute the atomic-weight weighted mean of the atom coordinates.

    Returns
    -------
    NDArray
        The center of mass of the structure.
    """
    weights = []
    for atype in self.element_symbol:
        weights.append(self._elements[atype].atomic_weight)
    weighted_sum = np.dot(self.atom_coordinate.T, weights)
    return weighted_sum / np.sum(weights)
733
+
734
def rigid_transform(
    self,
    rotation_matrix: NDArray,
    translation: NDArray,
    use_geometric_center: bool = False,
) -> "Structure":
    """
    Apply a rigid transform to the atom coordinates and return the result
    as a new instance. The current instance is left unchanged.

    Parameters
    ----------
    rotation_matrix : NDArray
        The rotation matrix to apply to the coordinates.
    translation : NDArray
        The vector to translate the coordinates by.
    use_geometric_center : bool, optional
        Whether to use geometric or coordinate center.

    Returns
    -------
    Structure
        The transformed instance of :py:class:`tme.structure.Structure`.
    """
    # The transform helper receives coordinates transposed, one column per atom.
    source = self.atom_coordinate.T
    out = np.empty_like(source)
    rigid_transform(
        coordinates=source,
        rotation_matrix=rotation_matrix,
        translation=translation,
        out=out,
        use_geometric_center=use_geometric_center,
    )
    transformed = self.copy()
    transformed.atom_coordinate = out.T.copy()
    return transformed
768
+
769
def centered(self) -> Tuple["Structure", NDArray]:
    """
    Shifts the structure analogous to :py:meth:`tme.density.Density.centered`.

    The translation moves the center of mass to half the extent of the
    minimum enclosing box of the atom coordinates.

    Returns
    -------
    Structure
        A copy of the class instance whose data center of mass is in the
        center of the data array.
    NDArray
        The coordinate translation.

    See Also
    --------
    :py:meth:`tme.Density.centered`
    """
    mass_center = self.center_of_mass()
    box = minimum_enclosing_box(coordinates=self.atom_coordinate.T)
    half_box = np.divide(box, 2)
    shift = np.subtract(half_box, mass_center)

    # Pure translation: the rotation is the identity matrix.
    identity = np.eye(shift.size)
    transformed_structure = self.rigid_transform(
        translation=shift, rotation_matrix=identity
    )
    return transformed_structure, shift
794
+
795
+ def _coordinate_to_position(
796
+ self,
797
+ shape: Tuple[int],
798
+ sampling_rate: Tuple[float],
799
+ origin: Tuple[float],
800
+ ) -> (NDArray, Tuple[str], Tuple[int], float, Tuple[float]):
801
+ """
802
+ Converts coordinates to positions.
803
+
804
+ Parameters
805
+ ----------
806
+ shape : Tuple[int,]
807
+ The desired shape of the output array.
808
+
809
+ sampling_rate : float
810
+ The sampling rate of the output array in unit of self.atom_coordinate.
811
+
812
+ origin : Tuple[float,]
813
+ The origin of the coordinate system.
814
+ Returns
815
+ -------
816
+ Tuple[NDArray, List[str], Tuple[int, ], float, Tuple[float,]]
817
+ Returns positions, atom_types, shape, sampling_rate, and origin.
818
+ """
819
+ coordinates = self.atom_coordinate.copy()
820
+ atom_types = self.element_symbol.copy()
821
+
822
+ # positions are in x, y, z map is z, y, x
823
+ coordinates = coordinates[:, ::-1]
824
+
825
+ sampling_rate = 1 if sampling_rate is None else sampling_rate
826
+ adjust_origin = origin is not None and shape is None
827
+ origin = coordinates.min(axis=0) if origin is None else origin
828
+ positions = (coordinates - origin) / sampling_rate
829
+ positions = np.rint(positions).astype(int)
830
+
831
+ if adjust_origin:
832
+ left_shift = positions.min(axis=0)
833
+ positions -= left_shift
834
+ shape = positions.max(axis=0) + 1
835
+ origin = origin + np.multiply(left_shift, sampling_rate)
836
+
837
+ if shape is None:
838
+ shape = positions.max(axis=0) + 1
839
+
840
+ valid_positions = np.sum(
841
+ np.logical_and(positions < shape, positions >= 0), axis=1
842
+ )
843
+
844
+ positions = positions[valid_positions == positions.shape[1], :]
845
+ atom_types = atom_types[valid_positions == positions.shape[1]]
846
+
847
+ self.details["nAtoms_outOfBound"] = 0
848
+ if positions.shape[0] != coordinates.shape[0]:
849
+ out_of_bounds = coordinates.shape[0] - positions.shape[0]
850
+ print(f"{out_of_bounds}/{coordinates.shape[0]} atoms were out of bounds.")
851
+ self.details["nAtoms_outOfBound"] = out_of_bounds
852
+
853
+ return positions, atom_types, shape, sampling_rate, origin
854
+
855
def _position_to_vdw_sphere(
    self,
    positions: Tuple[float],
    atoms: Tuple[str],
    sampling_rate: Tuple[float],
    volume: NDArray,
) -> None:
    """
    Stamp a binary van der Waals footprint into ``volume`` for every atom.

    The footprint of an element is computed once and cached. It is the set of
    grid offsets whose distance to the atom center, normalized per axis by the
    element's radius in voxels, is at most one. Footprints that reach over the
    volume border are cropped. ``volume`` is modified in place.

    Parameters
    ----------
    positions : Tuple[float, float, float]
        Grid position of each atom, one entry per atom.
    atoms : Tuple[str]
        Element symbol of each atom, index-aligned with ``positions``.
    sampling_rate : float
        Sampling rate per axis. The tabulated ``vdwr`` value is divided by
        ``sampling_rate * 100`` to obtain the radius in voxels.
    volume : NDArray
        The volume to update.
    """
    grid_shape = volume.shape
    # element symbol -> (per-axis radius in voxels, binary footprint)
    sphere_cache = {}

    for atom_index, center in enumerate(positions):
        symbol = atoms[atom_index]

        if symbol not in sphere_cache:
            radius = np.ceil(
                np.divide(self._elements[symbol].vdwr, (sampling_rate * 100))
            ).astype(int)
            window = tuple(slice(-extent, extent + 1) for extent in radius)
            axis_scale = radius.reshape((-1,) + (1,) * volume.ndim)
            normalized = np.linalg.norm(
                np.divide(np.mgrid[window], axis_scale), axis=0
            )
            sphere_cache[symbol] = (
                radius,
                (normalized <= 1).astype(volume.dtype),
            )

        radius, footprint = sphere_cache[symbol]

        # Unclipped bounding box of the footprint in volume coordinates.
        lower = np.subtract(center, radius)
        upper = np.add(center, radius) + 1

        # Part of the bounding box that lies inside the volume.
        target = tuple(
            slice(begin, end)
            for begin, end in zip(
                np.maximum(lower, 0), np.minimum(upper, grid_shape)
            )
        )

        # Matching part of the footprint after cropping the overhang.
        crop_begin = np.maximum(-lower, 0)
        crop_end = np.add(
            footprint.shape, np.minimum(np.subtract(grid_shape, upper), 0)
        )
        cropped = tuple(
            slice(begin, end) for begin, end in zip(crop_begin, crop_end)
        )

        volume[target] += footprint[cropped]
913
+
914
def _position_to_scattering_factors(
    self,
    positions: NDArray,
    atoms: NDArray,
    sampling_rate: NDArray,
    volume: NDArray,
    lowpass_filter: bool = True,
    downsampling_factor: float = 1.35,
    source: str = "peng1995",
) -> None:
    """
    Accumulate radial scattering profiles around each atom into ``volume``.

    For every atom the profile returned by ``atom_profile`` is evaluated at
    the distance between the atom position and each grid point inside a box
    spanned by the element's ``vdwr`` value. The result is added to ``volume``
    in place. Profiles are created once per element symbol.

    Parameters
    ----------
    positions : NDArray
        The positions of the atoms.
    atoms : NDArray
        Element symbols, index-aligned with ``positions``.
    sampling_rate : float
        Sampling rate that was used to convert coordinates to positions.
    volume : NDArray
        The volume to update.
    lowpass_filter : bool
        Whether the scattering factors should be lowpass filtered. Forwarded
        to ``atom_profile`` as ``lfilter``.
    downsampling_factor : float
        Downsampling factor for scattering factor computation. Forwarded to
        ``atom_profile`` as ``M``.
    source : str
        Which scattering factors to use. Forwarded to ``atom_profile`` as
        ``method``.

    Reference
    ---------
    https://github.com/I2PC/xmipp.
    """
    grid_shape = volume.shape
    profile_cache = {}

    for atom_index, center in enumerate(positions):
        symbol = atoms[atom_index]
        if symbol not in profile_cache:
            profile_cache[symbol] = atom_profile(
                atom=symbol,
                M=downsampling_factor,
                method=source,
                lfilter=lowpass_filter,
            )

        # Box around the atom, clipped to the volume; the upper bound is
        # exclusive because it is fed to range() below.
        radius = np.divide(self._elements[symbol].vdwr, sampling_rate * 100)
        lower = np.maximum(np.ceil(center - radius), 0).astype(int)
        upper = np.minimum(np.floor(center + radius), grid_shape).astype(int)

        axis_ranges = [range(begin, end) for begin, end in zip(lower, upper)]
        grid_index = np.meshgrid(*axis_ranges)

        squared_offsets = np.array(
            [
                (grid_index[axis] - center[axis]) ** 2
                for axis in range(len(center))
            ]
        )
        # Sum the per-axis squared offsets; the subscripts assume three axes.
        distances = np.sqrt(
            np.einsum("aijk->ijk", squared_offsets, dtype=np.float64)
        )

        # Empty leading dimension: fall back to the atom position itself.
        if len(distances) == 0:
            grid_index, distances = center, 0

        np.add.at(volume, tuple(grid_index), profile_cache[symbol](distances))
981
+
982
def _get_atom_weights(
    self, atoms: Tuple[str] = None, weight_type: str = "atomic_weight"
) -> Tuple[float]:
    """
    Look up a per-atom weight for the given element symbols.

    Parameters
    ----------
    atoms : Tuple of strings, optional
        Element symbols to look up. If None, ``self.element_symbol`` is used.
        Default is None.

    weight_type : str, optional
        Which property to return, either 'atomic_weight' or 'atomic_number'.
        Default is 'atomic_weight'.

    Returns
    -------
    List[float]
        One weight per entry in ``atoms``, in the same order.

    Raises
    ------
    NotImplementedError
        If ``weight_type`` is neither 'atomic_weight' nor 'atomic_number'.
    """
    if atoms is None:
        atoms = self.element_symbol

    if weight_type == "atomic_weight":
        return [self._elements[symbol].atomic_weight for symbol in atoms]

    if weight_type == "atomic_number":
        return [self._elements[symbol].atomic_number for symbol in atoms]

    raise NotImplementedError(
        "weight_type can either be 'atomic_weight' or 'atomic_number'"
    )
1014
+
1015
def to_volume(
    self,
    shape: Tuple[int] = None,
    sampling_rate: NDArray = None,
    origin: Tuple[float] = None,
    chain: str = None,
    weight_type: str = "atomic_weight",
    scattering_args: Dict = None,
) -> Tuple[NDArray, Tuple[int], NDArray]:
    """
    Converts atom coordinates of shape [n x 3] x, y, z to a volume with
    index z, y, x.

    Parameters
    ----------
    shape : Tuple[int, ...], optional
        Desired shape of the output array. If shape is given it is expected
        to be in z, y, x form.
    sampling_rate : float, optional
        Sampling rate of the output array in the unit of self.atom_coordinate.
        Either a single value or one value per coordinate axis. Defaults to
        one along every axis.
    origin : Tuple[float, ...], optional
        Origin of the coordinate system. If origin is given it is expected
        to be in z, y, x form.
    chain : str, optional
        The chain identifier. If multiple chains should be selected they need
        to be a comma separated string, e.g. 'A,B,CE'. If chain is None,
        all chains are returned. Default is None.
    weight_type : str, optional
        Which weight should be given to individual atoms. One of
        'atomic_weight', 'atomic_number', 'van_der_waals_radius',
        'scattering_factors' or 'lowpass_scattering_factors'.
    scattering_args : dict, optional
        Additional arguments for scattering factor computation. The passed
        dictionary is copied and never modified. If it has no 'source' key,
        'peng1995' is used.

    Returns
    -------
    Tuple[NDArray, Tuple[int], NDArray]
        The volume, its origin and the sampling rate per axis.

    Raises
    ------
    NotImplementedError
        If ``weight_type`` is not one of the supported values.
    ValueError
        If ``sampling_rate`` is neither a single value nor one value per
        coordinate axis.
    """
    _weight_types = {
        "atomic_weight",
        "atomic_number",
        "van_der_waals_radius",
        "scattering_factors",
        "lowpass_scattering_factors",
    }
    if weight_type not in _weight_types:
        # Sorted so the message does not depend on set iteration order.
        _weight_string = ",".join([f"'{x}'" for x in sorted(_weight_types)])
        raise NotImplementedError(f"weight_type needs to be in {_weight_string}")

    n_axes = self.atom_coordinate.shape[1]
    if sampling_rate is None:
        sampling_rate = np.ones(n_axes)
    sampling_rate = np.array(sampling_rate)
    if sampling_rate.size == 1:
        sampling_rate = np.repeat(sampling_rate, n_axes)
    elif sampling_rate.size != n_axes:
        raise ValueError(
            "sampling_rate should either be a single value or an array with "
            f"size {n_axes}."
        )

    # Work on a private copy: writing the default 'source' into the argument
    # would leak into the caller's dict (or into a shared default object).
    scattering_args = {} if scattering_args is None else dict(scattering_args)
    scattering_args.setdefault("source", "peng1995")

    temp = self.subset_by_chain(chain=chain)

    positions, atoms, shape, sampling_rate, origin = temp._coordinate_to_position(
        shape=shape, sampling_rate=sampling_rate, origin=origin
    )
    volume = np.zeros(shape, dtype=np.float32)
    if weight_type in ("atomic_weight", "atomic_number"):
        weights = temp._get_atom_weights(atoms=atoms, weight_type=weight_type)
        # add.at accumulates correctly when several atoms share one voxel.
        np.add.at(volume, tuple(positions.T), weights)
    elif weight_type == "van_der_waals_radius":
        self._position_to_vdw_sphere(positions, atoms, sampling_rate, volume)
    elif weight_type == "scattering_factors":
        self._position_to_scattering_factors(
            positions,
            atoms,
            sampling_rate,
            volume,
            lowpass_filter=False,
            **scattering_args,
        )
    elif weight_type == "lowpass_scattering_factors":
        self._position_to_scattering_factors(
            positions,
            atoms,
            sampling_rate,
            volume,
            lowpass_filter=True,
            **scattering_args,
        )

    # Propagate bookkeeping written on the chain subset back to this instance.
    self.details.update(temp.details)
    return volume, origin, sampling_rate
1108
+
1109
@classmethod
def compare_structures(
    cls,
    structure1: "Structure",
    structure2: "Structure",
    origin: NDArray = None,
    sampling_rate: float = None,
    weighted: bool = False,
) -> float:
    """
    Compute root mean square deviation (RMSD) between two structures.

    Both structures need to have the same number of atoms. In practice, this
    means that *structure2* is a transformed version of *structure1*.

    Parameters
    ----------
    structure1 : Structure
        Structure 1.

    structure2 : Structure
        Structure 2.

    origin : NDArray, optional
        Origin of the structure coordinate system. Defaults to zero along
        every axis. Only used when ``sampling_rate`` is given.

    sampling_rate : float, optional
        Sampling rate if discretized on a grid in the unit of
        self.atom_coordinate. If given, coordinates are shifted by ``origin``,
        divided by the sampling rate and rounded to integers before comparison.

    weighted : bool, optional
        Whether atoms should be weighted by their atomic weight. If True, each
        squared deviation is multiplied by the mean of the two atoms' weights
        before averaging.

    Returns
    -------
    float
        Root Mean Square Deviation (RMSD)

    Raises
    ------
    ValueError
        If the structures differ in coordinate shape or number of atoms.
    """
    if origin is None:
        origin = np.zeros(structure1.atom_coordinate.shape[1])

    coordinates1 = structure1.atom_coordinate
    coordinates2 = structure2.atom_coordinate
    atoms1, atoms2 = structure1.element_symbol, structure2.element_symbol
    if sampling_rate is not None:
        coordinates1 = np.rint((coordinates1 - origin) / sampling_rate).astype(int)
        coordinates2 = np.rint((coordinates2 - origin) / sampling_rate).astype(int)

    weights1 = np.array(structure1._get_atom_weights(atoms=atoms1))
    weights2 = np.array(structure2._get_atom_weights(atoms=atoms2))
    if not weighted:
        weights1 = np.ones_like(weights1)
        weights2 = np.ones_like(weights2)

    # Shapes are compared exactly. A tolerance-based comparison would treat
    # large atom counts that differ by a few atoms as equal.
    if tuple(coordinates1.shape) != tuple(coordinates2.shape):
        raise ValueError(
            "Input structures need to have the same number of coordinates."
        )
    if tuple(weights1.shape) != tuple(weights2.shape):
        raise ValueError("Input structures need to have the same number of atoms.")

    squared_diff = np.sum(np.square(coordinates1 - coordinates2), axis=1)
    weighted_squared_diff = squared_diff * ((weights1 + weights2) / 2)
    rmsd = np.sqrt(np.mean(weighted_squared_diff))

    return rmsd
1174
+
1175
@classmethod
def align_structures(
    cls,
    structure1: "Structure",
    structure2: "Structure",
    origin: NDArray = None,
    sampling_rate: float = None,
    weighted: bool = False,
) -> Tuple["Structure", float]:
    """
    Superimpose structure2 onto structure1 with the Kabsch algorithm.

    Both structures need to have the same number of atoms, matched by index.
    The RMSD before and after alignment is printed.

    Parameters
    ----------
    structure1 : Structure
        Reference structure.

    structure2 : Structure
        Structure that is rotated and translated onto the reference.

    origin : NDArray, optional
        Origin of the structure coordinate system. If None, the element-wise
        minimum coordinate over both structures, cast to int, is used.

    sampling_rate : float, optional
        Voxel size if discretized on a grid. If given, both structures are
        converted with ``_coordinate_to_position`` before alignment.

    weighted : bool, optional
        Forwarded to ``compare_structures`` for the reported RMSD values.

    Returns
    -------
    Structure
        Copy of *structure2* aligned to *structure1*.
    float
        RMSD between the reference and the aligned copy.
    """
    if origin is None:
        lowest_corner = np.minimum(
            structure1.atom_coordinate.min(axis=0),
            structure2.atom_coordinate.min(axis=0),
        )
        origin = lowest_corner.astype(int)

    rmsd_before = cls.compare_structures(
        structure1=structure1,
        structure2=structure2,
        origin=origin,
        sampling_rate=sampling_rate,
        weighted=weighted,
    )

    reference = structure1.atom_coordinate.copy()
    query = structure2.atom_coordinate.copy()
    if sampling_rate is not None:
        # Only the grid positions are needed from the conversion.
        reference, _, _, _, _ = structure1._coordinate_to_position(
            shape=None, sampling_rate=sampling_rate, origin=origin
        )
        query, _, _, _, _ = structure2._coordinate_to_position(
            shape=None, sampling_rate=sampling_rate, origin=origin
        )

    # Center both point sets on their centroids.
    reference_centroid = reference.mean(axis=0)
    query_centroid = query.mean(axis=0)
    reference_centered = reference - reference_centroid
    query_centered = query - query_centroid

    # Rotation from the SVD of the cross-covariance matrix.
    covariance = np.dot(query_centered.T, reference_centered)
    U, _, Vh = np.linalg.svd(covariance)
    rotation = np.dot(Vh.T, U.T).T
    if np.linalg.det(rotation) < 0:
        # Negative determinant means a reflection; flip the last axis.
        Vh[2, :] = -Vh[2, :]
        rotation = np.dot(Vh.T, U.T).T

    translation = reference_centroid - np.dot(query_centroid, rotation)

    aligned_reference = structure1.copy()
    aligned_reference.atom_coordinate = reference_centered + reference_centroid
    aligned_query = structure2.copy()
    aligned_query.atom_coordinate = (
        np.dot(query_centered + query_centroid, rotation) + translation
    )

    rmsd_after = cls.compare_structures(
        structure1=aligned_reference,
        structure2=aligned_query,
        origin=origin,
        sampling_rate=None,
        weighted=weighted,
    )

    print(f"Initial RMSD: {rmsd_before:.5f} - Final RMSD: {rmsd_after:.5f}")

    return aligned_query, rmsd_after
1271
+
1272
+
1273
@dataclass(frozen=True, repr=True)
class Elements:
    """
    Lookup table containing information on chemical elements.

    Entries are keyed by upper-case element symbol and accessed via indexing,
    e.g. ``Elements()["FE"]``. Unknown symbols yield a placeholder entry whose
    numeric fields are all zero instead of raising.
    """

    # Record layout of a single table entry.
    # NOTE(review): radii look like picometres (callers in this file divide
    # vdwr by 100 before use) - confirm against the data source.
    Atom = namedtuple(
        "Atom",
        [
            "atomic_number",
            "atomic_radius",
            "lattice_constant",
            "lattice_structure",
            "vdwr",
            "covalent_radius_bragg",
            "atomic_weight",
        ],
    )
    # Placeholder returned for symbols that are not in the table.
    _default = Atom(0, 0, 0, "Atom does not exist in ressource.", 0, 0, 0)
    # Values that are missing from the table are stored as np.nan.
    _elements = {
        "H": Atom(1, 25, 3.75, "HEX", 110, np.nan, 1.008),
        "HE": Atom(2, 120, 3.57, "HEX", 140, np.nan, 4.002602),
        "LI": Atom(3, 145, 3.49, "BCC", 182, 150, 6.94),
        "BE": Atom(4, 105, 2.29, "HEX", 153, 115, 9.0121831),
        "B": Atom(5, 85, 8.73, "TET", 192, np.nan, 10.81),
        "C": Atom(6, 70, 3.57, "DIA", 170, 77, 12.011),
        "N": Atom(7, 65, 4.039, "HEX", 155, 65, 14.007),
        "O": Atom(8, 60, 6.83, "CUB", 152, 65, 15.999),
        "F": Atom(9, 50, np.nan, "MCL", 147, 67, 18.998403163),
        "NE": Atom(10, 160, 4.43, "FCC", 154, np.nan, 20.1797),
        "NA": Atom(11, 180, 4.23, "BCC", 227, 177, 22.98976928),
        "MG": Atom(12, 150, 3.21, "HEX", 173, 142, 24.305),
        "AL": Atom(13, 125, 4.05, "FCC", 184, 135, 26.9815385),
        "SI": Atom(14, 110, 5.43, "DIA", 210, 117, 28.085),
        "P": Atom(15, 100, 7.17, "CUB", 180, np.nan, 30.973761998),
        "S": Atom(16, 100, 10.47, "ORC", 180, 102, 32.06),
        "CL": Atom(17, 100, 6.24, "ORC", 175, 105, 35.45),
        "AR": Atom(18, 71, 5.26, "FCC", 188, np.nan, 39.948),
        "K": Atom(19, 220, 5.23, "BCC", 275, 207, 39.0983),
        "CA": Atom(20, 180, 5.58, "FCC", 231, 170, 40.078),
        "SC": Atom(21, 160, 3.31, "HEX", 215, np.nan, 44.955908),
        "TI": Atom(22, 140, 2.95, "HEX", 211, 140, 47.867),
        "V": Atom(23, 135, 3.02, "BCC", 207, np.nan, 50.9415),
        "CR": Atom(24, 140, 2.88, "BCC", 206, 140, 51.9961),
        "MN": Atom(25, 140, 8.89, "CUB", 205, 147, 54.938044),
        "FE": Atom(26, 140, 2.87, "BCC", 204, 140, 55.845),
        "CO": Atom(27, 135, 2.51, "HEX", 200, 137, 58.933194),
        "NI": Atom(28, 135, 3.52, "FCC", 197, 135, 58.6934),
        "CU": Atom(29, 135, 3.61, "FCC", 196, 137, 63.546),
        "ZN": Atom(30, 135, 2.66, "HEX", 201, 132, 65.38),
        "GA": Atom(31, 130, 4.51, "ORC", 187, np.nan, 69.723),
        "GE": Atom(32, 125, 5.66, "DIA", 211, np.nan, 72.63),
        "AS": Atom(33, 115, 4.13, "RHL", 185, 126, 74.921595),
        "SE": Atom(34, 115, 4.36, "HEX", 190, 117, 78.971),
        "BR": Atom(35, 115, 6.67, "ORC", 185, 119, 79.904),
        "KR": Atom(36, np.nan, 5.72, "FCC", 202, np.nan, 83.798),
        "RB": Atom(37, 235, 5.59, "BCC", 303, 225, 85.4678),
        "SR": Atom(38, 200, 6.08, "FCC", 249, 195, 87.62),
        "Y": Atom(39, 180, 3.65, "HEX", 232, np.nan, 88.90584),
        "ZR": Atom(40, 155, 3.23, "HEX", 223, np.nan, 91.224),
        "NB": Atom(41, 145, 3.3, "BCC", 218, np.nan, 92.90637),
        "MO": Atom(42, 145, 3.15, "BCC", 217, np.nan, 95.95),
        "TC": Atom(43, 135, 2.74, "HEX", 216, np.nan, 97.90721),
        "RU": Atom(44, 130, 2.7, "HEX", 213, np.nan, 101.07),
        "RH": Atom(45, 135, 3.8, "FCC", 210, np.nan, 102.9055),
        "PD": Atom(46, 140, 3.89, "FCC", 210, np.nan, 106.42),
        "AG": Atom(47, 160, 4.09, "FCC", 211, 177, 107.8682),
        "CD": Atom(48, 155, 2.98, "HEX", 218, 160, 112.414),
        "IN": Atom(49, 155, 4.59, "TET", 193, np.nan, 114.818),
        "SN": Atom(50, 145, 5.82, "TET", 217, 140, 118.71),
        "SB": Atom(51, 145, 4.51, "RHL", 206, 140, 121.76),
        "TE": Atom(52, 140, 4.45, "HEX", 206, 133, 127.6),
        "I": Atom(53, 140, 7.72, "ORC", 198, 140, 126.90447),
        "XE": Atom(54, np.nan, 6.2, "FCC", 216, np.nan, 131.293),
        "CS": Atom(55, 260, 6.05, "BCC", 343, 237, 132.90545196),
        "BA": Atom(56, 215, 5.02, "BCC", 268, 210, 137.327),
        "LA": Atom(57, 195, 3.75, "HEX", 243, np.nan, 138.90547),
        "CE": Atom(58, 185, 5.16, "FCC", 242, np.nan, 140.116),
        "PR": Atom(59, 185, 3.67, "HEX", 240, np.nan, 140.90766),
        "ND": Atom(60, 185, 3.66, "HEX", 239, np.nan, 144.242),
        "PM": Atom(61, 185, np.nan, "", 238, np.nan, 144.91276),
        "SM": Atom(62, 185, 9, "RHL", 236, np.nan, 150.36),
        "EU": Atom(63, 185, 4.61, "BCC", 235, np.nan, 151.964),
        "GD": Atom(64, 180, 3.64, "HEX", 234, np.nan, 157.25),
        "TB": Atom(65, 175, 3.6, "HEX", 233, np.nan, 158.92535),
        "DY": Atom(66, 175, 3.59, "HEX", 231, np.nan, 162.5),
        "HO": Atom(67, 175, 3.58, "HEX", 230, np.nan, 164.93033),
        "ER": Atom(68, 175, 3.56, "HEX", 229, np.nan, 167.259),
        "TM": Atom(69, 175, 3.54, "HEX", 227, np.nan, 168.93422),
        "YB": Atom(70, 175, 5.49, "FCC", 226, np.nan, 173.045),
        "LU": Atom(71, 175, 3.51, "HEX", 224, np.nan, 174.9668),
        "HF": Atom(72, 155, 3.2, "HEX", 223, np.nan, 178.49),
        "TA": Atom(73, 145, 3.31, "BCC", 222, np.nan, 180.94788),
        "W": Atom(74, 135, 3.16, "BCC", 218, np.nan, 183.84),
        "RE": Atom(75, 135, 2.76, "HEX", 216, np.nan, 186.207),
        "OS": Atom(76, 130, 2.74, "HEX", 216, np.nan, 190.23),
        "IR": Atom(77, 135, 3.84, "FCC", 213, np.nan, 192.217),
        "PT": Atom(78, 135, 3.92, "FCC", 213, np.nan, 195.084),
        "AU": Atom(79, 135, 4.08, "FCC", 214, np.nan, 196.966569),
        "HG": Atom(80, 150, 2.99, "RHL", 223, np.nan, 200.592),
        "TL": Atom(81, 190, 3.46, "HEX", 196, 190, 204.38),
        "PB": Atom(82, 180, 4.95, "FCC", 202, np.nan, 207.2),
        "BI": Atom(83, 160, 4.75, "RHL", 207, 148, 208.9804),
        "PO": Atom(84, 190, 3.35, "SC", 197, np.nan, 209),
        "AT": Atom(85, np.nan, np.nan, "", 202, np.nan, 210),
        "RN": Atom(86, np.nan, np.nan, "FCC", 220, np.nan, 222),
        "FR": Atom(87, np.nan, np.nan, "BCC", 348, np.nan, 223),
        "RA": Atom(88, 215, np.nan, "", 283, np.nan, 226),
        "AC": Atom(89, 195, 5.31, "FCC", 247, np.nan, 227),
        "TH": Atom(90, 180, 5.08, "FCC", 245, np.nan, 232.0377),
        "PA": Atom(91, 180, 3.92, "TET", 243, np.nan, 231.03588),
        "U": Atom(92, 175, 2.85, "ORC", 241, np.nan, 238.02891),
        "NP": Atom(93, 175, 4.72, "ORC", 239, np.nan, 237),
        "PU": Atom(94, 175, np.nan, "MCL", 243, np.nan, 244),
        "AM": Atom(95, 175, np.nan, "", 244, np.nan, 243),
        "CM": Atom(96, np.nan, np.nan, "", 245, np.nan, 247),
        "BK": Atom(97, np.nan, np.nan, "", 244, np.nan, 247),
        "CF": Atom(98, np.nan, np.nan, "", 245, np.nan, 251),
        "ES": Atom(99, np.nan, np.nan, "", 245, np.nan, 252),
        "FM": Atom(100, np.nan, np.nan, "", 245, np.nan, 257),
        "MD": Atom(101, np.nan, np.nan, "", 246, np.nan, 258),
        "NO": Atom(102, np.nan, np.nan, "", 246, np.nan, 259),
        "LR": Atom(103, np.nan, np.nan, "", 246, np.nan, 262),
        "RF": Atom(104, np.nan, np.nan, "", np.nan, np.nan, 267),
        "DB": Atom(105, np.nan, np.nan, "", np.nan, np.nan, 268),
        "SG": Atom(106, np.nan, np.nan, "", np.nan, np.nan, 271),
        "BH": Atom(107, np.nan, np.nan, "", np.nan, np.nan, 274),
        "HS": Atom(108, np.nan, np.nan, "", np.nan, np.nan, 269),
        "MT": Atom(109, np.nan, np.nan, "", np.nan, np.nan, 276),
        "DS": Atom(110, np.nan, np.nan, "", np.nan, np.nan, 281),
        "RG": Atom(111, np.nan, np.nan, "", np.nan, np.nan, 281),
        "CN": Atom(112, np.nan, np.nan, "", np.nan, np.nan, 285),
        "NH": Atom(113, np.nan, np.nan, "", np.nan, np.nan, 286),
        "FL": Atom(114, np.nan, np.nan, "", np.nan, np.nan, 289),
        "MC": Atom(115, np.nan, np.nan, "", np.nan, np.nan, 288),
        "LV": Atom(116, np.nan, np.nan, "", np.nan, np.nan, 293),
        "TS": Atom(117, np.nan, np.nan, "", np.nan, np.nan, 294),
        "OG": Atom(118, np.nan, np.nan, "", np.nan, np.nan, 294),
    }

    def __getitem__(self, key: str):
        """
        Retrieve the table entry for an element symbol.

        The key is first looked up verbatim. If that fails and the key is a
        string, it is looked up again with surrounding whitespace removed and
        converted to upper case, so that 'Fe' or ' c ' resolve to the 'FE' and
        'C' entries instead of the placeholder.

        Parameters
        ----------
        key : str
            Element symbol.

        Returns
        -------
        Atom
            The entry for the element, or the zero-valued placeholder entry if
            the symbol is not in the table.
        """
        try:
            return self._elements[key]
        except KeyError:
            pass

        if isinstance(key, str):
            return self._elements.get(key.strip().upper(), self._default)
        return self._default