pytme 0.3.1.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2.dev0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. pytme-0.3.2.dev0.data/scripts/estimate_ram_usage.py +97 -0
  2. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/match_template.py +213 -196
  3. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/postprocess.py +40 -78
  4. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/preprocess.py +4 -5
  5. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/preprocessor_gui.py +50 -103
  6. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/pytme_runner.py +46 -69
  7. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/METADATA +2 -1
  8. pytme-0.3.2.dev0.dist-info/RECORD +136 -0
  9. scripts/estimate_ram_usage.py +97 -0
  10. scripts/match_template.py +213 -196
  11. scripts/match_template_devel.py +1339 -0
  12. scripts/postprocess.py +40 -78
  13. scripts/preprocess.py +4 -5
  14. scripts/preprocessor_gui.py +50 -103
  15. scripts/pytme_runner.py +46 -69
  16. scripts/refine_matches.py +5 -7
  17. tests/preprocessing/test_compose.py +31 -30
  18. tests/preprocessing/test_frequency_filters.py +17 -32
  19. tests/preprocessing/test_preprocessor.py +0 -19
  20. tests/preprocessing/test_utils.py +13 -1
  21. tests/test_analyzer.py +2 -10
  22. tests/test_backends.py +47 -18
  23. tests/test_density.py +72 -13
  24. tests/test_extensions.py +1 -0
  25. tests/test_matching_cli.py +23 -9
  26. tests/test_matching_exhaustive.py +5 -5
  27. tests/test_matching_utils.py +3 -3
  28. tests/test_rotations.py +13 -23
  29. tests/test_structure.py +1 -7
  30. tme/__version__.py +1 -1
  31. tme/analyzer/aggregation.py +47 -16
  32. tme/analyzer/base.py +34 -0
  33. tme/analyzer/peaks.py +26 -13
  34. tme/analyzer/proxy.py +14 -0
  35. tme/backends/_jax_utils.py +124 -71
  36. tme/backends/cupy_backend.py +6 -19
  37. tme/backends/jax_backend.py +110 -105
  38. tme/backends/matching_backend.py +0 -17
  39. tme/backends/mlx_backend.py +0 -29
  40. tme/backends/npfftw_backend.py +100 -97
  41. tme/backends/pytorch_backend.py +65 -78
  42. tme/cli.py +2 -2
  43. tme/density.py +102 -58
  44. tme/extensions.cpython-311-darwin.so +0 -0
  45. tme/filters/_utils.py +52 -24
  46. tme/filters/bandpass.py +99 -105
  47. tme/filters/compose.py +133 -39
  48. tme/filters/ctf.py +51 -102
  49. tme/filters/reconstruction.py +67 -122
  50. tme/filters/wedge.py +296 -325
  51. tme/filters/whitening.py +39 -75
  52. tme/mask.py +2 -2
  53. tme/matching_data.py +87 -15
  54. tme/matching_exhaustive.py +70 -120
  55. tme/matching_optimization.py +9 -63
  56. tme/matching_scores.py +261 -100
  57. tme/matching_utils.py +150 -91
  58. tme/memory.py +1 -0
  59. tme/orientations.py +28 -8
  60. tme/preprocessor.py +0 -239
  61. tme/rotations.py +102 -70
  62. tme/structure.py +601 -631
  63. tme/types.py +1 -0
  64. pytme-0.3.1.post1.dist-info/RECORD +0 -133
  65. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/estimate_memory_usage.py +0 -0
  66. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/WHEEL +0 -0
  67. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/entry_points.txt +0 -0
  68. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/licenses/LICENSE +0 -0
  69. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/top_level.txt +0 -0
tme/structure.py CHANGED
@@ -20,7 +20,7 @@ from .types import NDArray
20
20
  from .rotations import align_to_axis
21
21
  from .preprocessor import atom_profile, Preprocessor
22
22
  from .parser import PDBParser, MMCIFParser, GROParser
23
- from .matching_utils import rigid_transform, minimum_enclosing_box
23
+ from .matching_utils import _rigid_transform
24
24
 
25
25
  __all__ = ["Structure"]
26
26
 
@@ -91,49 +91,35 @@ class Structure:
91
91
 
92
92
  """
93
93
 
94
- #: Array of record types, e.g.ATOM.
94
+ #: Array of record types, e.g., ATOM (n,)
95
95
  record_type: NDArray
96
-
97
- #: Array of serial numbers.
96
+ #: Array of serial numbers (n,)
98
97
  atom_serial_number: NDArray
99
-
100
- #: Array of atom names.
98
+ #: Array of atom names (n,)
101
99
  atom_name: NDArray
102
-
103
- #: Array of x,y,z atom coordinates.
100
+ #: Array of x,y,z atom coordinates (n, d)
104
101
  atom_coordinate: NDArray
105
-
106
- #: Array of alternate location indices.
102
+ #: Array of alternate location indices (n,)
107
103
  alternate_location_indicator: NDArray
108
-
109
- #: Array of residue names.
104
+ #: Array of residue names (n,)
110
105
  residue_name: NDArray
111
-
112
- #: Array of chain identifiers.
106
+ #: Array of chain identifiers (n,)
113
107
  chain_identifier: NDArray
114
-
115
- #: Array of residue ids.
108
+ #: Array of residue ids (n,)
116
109
  residue_sequence_number: NDArray
117
-
118
- #: Array of insertion information.
110
+ #: Array of insertion information (n,)
119
111
  code_for_residue_insertion: NDArray
120
-
121
- #: Array of occupancy factors.
112
+ #: Array of occupancy factors (n,)
122
113
  occupancy: NDArray
123
-
124
- #: Array of B-factors.
114
+ #: Array of B-factors (n,)
125
115
  temperature_factor: NDArray
126
-
127
- #: Array of segment identifiers.
116
+ #: Array of segment identifiers (n,)
128
117
  segment_identifier: NDArray
129
-
130
- #: Array of element symbols.
118
+ #: Array of element symbols (n,)
131
119
  element_symbol: NDArray
132
-
133
- #: Array of charges.
120
+ #: Array of charges (n,)
134
121
  charge: NDArray
135
-
136
- #: Metadata dictionary.
122
+ #: Metadata dictionary
137
123
  metadata: dict
138
124
 
139
125
  def __post_init__(self, *args, **kwargs):
@@ -370,9 +356,9 @@ class Structure:
370
356
  """
371
357
  _, file_extension = splitext(basename(filename.lower()))
372
358
  _formats = {
373
- ".pdb": cls._load_pdb,
374
- ".cif": cls._load_mmcif,
375
- ".gro": cls._load_gro,
359
+ ".pdb": _parse_pdb,
360
+ ".cif": _parse_mmcif,
361
+ ".gro": _parse_gro,
376
362
  }
377
363
  func = _formats.get(file_extension)
378
364
  if func is None:
@@ -382,7 +368,7 @@ class Structure:
382
368
  f"Supported filetypes are {formats}."
383
369
  )
384
370
 
385
- data = func(cls, filename)
371
+ data = func(filename)
386
372
  keep = np.ones(data["element_symbol"].size, dtype=bool)
387
373
  if filter_by_elements:
388
374
  keep = np.logical_and(
@@ -411,220 +397,6 @@ class Structure:
411
397
 
412
398
  return cls(**data)
413
399
 
414
- @staticmethod
415
- def _convert_dtypes(data: Dict[str, List], mapping: Dict):
416
- """
417
- Convert key values in data according to mapping.
418
-
419
- Parameters
420
- ----------
421
- data : Dict
422
- Mapping of keys to list of values
423
- mapping : Dict
424
- Mapping of key in return dict to (key, dtype) in data.
425
-
426
- Returns
427
- -------
428
- dict
429
- Key-value map using key-dtype pairs in mapping on data.
430
- """
431
- out = {}
432
- max_len = max([len(t) for t in data.values() if hasattr(t, "__len__")])
433
-
434
- missing_keys = set()
435
- for out_key, (inner_key, dtype) in mapping.items():
436
- default = "." if dtype is str else 0
437
- if inner_key in data:
438
- continue
439
- missing_keys.add(inner_key)
440
- out[out_key] = np.repeat(default, max_len).astype(dtype)
441
-
442
- if len(missing_keys):
443
- msg = ", ".join([f"'{x}'" for x in missing_keys])
444
- warnings.warn(
445
- f"Missing keys: ({msg}) in data - filling with default value."
446
- )
447
-
448
- for out_key, (inner_key, dtype) in mapping.items():
449
- default = "." if dtype is str else 0
450
-
451
- # Avoid modifying input dictionaries
452
- if inner_key in missing_keys:
453
- continue
454
-
455
- out_data = data[inner_key]
456
- if isinstance(data[inner_key][0], str):
457
- out_data = [str(x).strip() for x in data[inner_key]]
458
-
459
- out_data = np.asarray(out_data)
460
- if dtype is int:
461
- out_data = np.where(out_data == ".", "0", out_data)
462
- elif dtype == "base-36":
463
- dtype = int
464
- base36_offset = int("A0000", 36) - 100000
465
- out_data = np.where(
466
- np.char.isdigit(out_data),
467
- out_data,
468
- np.vectorize(lambda x: int(x, 36) - base36_offset)(out_data),
469
- )
470
- try:
471
- out[out_key] = np.asarray(out_data, dtype=dtype)
472
- except ValueError:
473
- print(
474
- f"Converting {out_key} to {dtype} failed. Setting {out_key} to {default}."
475
- )
476
- out[out_key] = np.repeat(default, max_len).astype(dtype)
477
- return out
478
-
479
- def _load_mmcif(self, filename: str) -> Dict:
480
- """
481
- Parses a macromolecular Crystallographic Information File (mmCIF)
482
- and returns the data in a dictionary format.
483
-
484
- Parameters
485
- ----------
486
- filename : str
487
- The filename of the mmCIF to load.
488
-
489
- Returns
490
- -------
491
- dict
492
- A dictionary of numpy arrays. Keys are the names of the PDB
493
- coordinate section. In addition, some details about the parsed
494
- structure are included. In case of conversion failure, the failing
495
- attribute is set to 0 if its supposed to be an integer value.
496
- """
497
- result = MMCIFParser(filename)
498
-
499
- atom_site_mapping = {
500
- "record_type": ("group_PDB", str),
501
- "atom_serial_number": ("id", int),
502
- "atom_name": ("label_atom_id", str),
503
- "alternate_location_indicator": ("label_alt_id", str),
504
- "residue_name": ("label_comp_id", str),
505
- # "chain_identifier": ("auth_asym_id", str),
506
- "chain_identifier": ("label_asym_id", str),
507
- "residue_sequence_number": ("label_seq_id", int),
508
- "code_for_residue_insertion": ("pdbx_PDB_ins_code", str),
509
- "occupancy": ("occupancy", float),
510
- "temperature_factor": ("B_iso_or_equiv", float),
511
- "segment_identifier": ("label_entity_id", str),
512
- "element_symbol": ("type_symbol", str),
513
- "charge": ("pdbx_formal_charge", str),
514
- }
515
-
516
- out = self._convert_dtypes(result["atom_site"], atom_site_mapping)
517
- number_entries = len(max(out.values(), key=len))
518
- for key, value in out.items():
519
- if value.size != 1:
520
- continue
521
- out[key] = np.repeat(value, number_entries // value.size)
522
-
523
- out["metadata"] = {}
524
- out["atom_coordinate"] = np.transpose(
525
- np.array(
526
- [
527
- result["atom_site"]["Cartn_x"],
528
- result["atom_site"]["Cartn_y"],
529
- result["atom_site"]["Cartn_z"],
530
- ],
531
- dtype=np.float32,
532
- )
533
- )
534
-
535
- detail_mapping = {
536
- "resolution": ("em_3d_reconstruction", "resolution", np.nan),
537
- "resolution_method": ("em_3d_reconstruction", "resolution_method", np.nan),
538
- "method": ("exptl", "method", np.nan),
539
- "electron_source": ("em_imaging", "electron_source", np.nan),
540
- "illumination_mode": ("em_imaging", "illumination_mode", np.nan),
541
- "microscope_model": ("em_imaging", "microscope_model", np.nan),
542
- }
543
- for out_key, (base_key, inner_key, default) in detail_mapping.items():
544
- if base_key not in result:
545
- continue
546
- out["metadata"][out_key] = result[base_key].get(inner_key, default)
547
-
548
- return out
549
-
550
- def _load_pdb(self, filename: str) -> Dict:
551
- """
552
- Parses a Protein Data Bank (PDB) file and returns the data
553
- in a dictionary format.
554
-
555
- Parameters
556
- ----------
557
- filename : str
558
- The filename of the PDB file to load.
559
-
560
- Returns
561
- -------
562
- dict
563
- A dictionary of numpy arrays. Keys are the names of the PDB
564
- coordinate section. In addition, some details about the parsed
565
- structure are included. In case of conversion failure, the failing
566
- attribute is set to 0 if its supposed to be an integer value.
567
- """
568
- result = PDBParser(filename)
569
-
570
- atom_site_mapping = {
571
- "record_type": ("record_type", str),
572
- "atom_serial_number": ("atom_serial_number", "base-36"),
573
- "atom_name": ("atom_name", str),
574
- "alternate_location_indicator": ("alternate_location_indicator", str),
575
- "residue_name": ("residue_name", str),
576
- "chain_identifier": ("chain_identifier", str),
577
- "residue_sequence_number": ("residue_sequence_number", int),
578
- "code_for_residue_insertion": ("code_for_residue_insertion", str),
579
- "occupancy": ("occupancy", float),
580
- "temperature_factor": ("temperature_factor", float),
581
- "segment_identifier": ("segment_identifier", str),
582
- "element_symbol": ("element_symbol", str),
583
- "charge": ("charge", str),
584
- }
585
-
586
- out = self._convert_dtypes(result, atom_site_mapping)
587
-
588
- out["metadata"] = result["details"]
589
- out["atom_coordinate"] = np.array(result["atom_coordinate"], dtype=np.float32)
590
-
591
- return out
592
-
593
- def _load_gro(self, filename):
594
- result = GROParser(filename)
595
-
596
- atom_site_mapping = {
597
- "record_type": ("record_type", str),
598
- "atom_serial_number": ("atom_number", int),
599
- "atom_name": ("atom_name", str),
600
- "alternate_location_indicator": ("label_alt_id", str),
601
- "residue_name": ("residue_name", str),
602
- "chain_identifier": ("segment_identifier", str),
603
- "residue_sequence_number": ("residue_number", int),
604
- "code_for_residue_insertion": ("pdbx_PDB_ins_code", str),
605
- "occupancy": ("occupancy", float),
606
- "temperature_factor": ("B_iso_or_equiv", float),
607
- "segment_identifier": ("segment_identifier", str),
608
- "element_symbol": ("type_symbol", str),
609
- "charge": ("pdbx_formal_charge", str),
610
- }
611
-
612
- out = self._convert_dtypes(result, atom_site_mapping)
613
-
614
- unique_chains = np.unique(out["segment_identifier"])
615
- if len(unique_chains) > 1:
616
- warnings.warn(
617
- "Multiple GRO files detected - treating them as a single Structure. "
618
- "GRO file number is given by segment_identifier according to the "
619
- "input file. Note: You need to subset the Structure to operate on "
620
- "individual GRO files."
621
- )
622
-
623
- mkeys = ("title", "time", "box_vectors")
624
- out["metadata"] = {key: result.get(key) for key in mkeys}
625
- out["atom_coordinate"] = np.asarray(result["atom_coordinate"], dtype=np.float32)
626
- return out
627
-
628
400
  def to_file(self, filename: str) -> None:
629
401
  """
630
402
  Writes the :py:class:`Structure` instance to disk.
@@ -656,220 +428,25 @@ class Structure:
656
428
  >>> structure = Structure.from_file(filename=fname)
657
429
  >>> structure.to_file(f"{oname}.cif") # Writes an mmCIF file to disk
658
430
  >>> structure.to_file(f"{oname}.pdb") # Writes a PDB file to disk
659
-
660
431
  """
661
432
  _, file_extension = splitext(basename(filename.lower()))
662
433
  _formats = {
663
- ".pdb": self._write_pdb,
664
- ".cif": self._write_mmcif,
665
- ".gro": self._write_gro,
434
+ ".pdb": _to_pdb,
435
+ ".cif": _to_mmcif,
436
+ ".gro": _to_gro,
666
437
  }
667
438
  func = _formats.get(file_extension)
668
439
  if func is None:
669
440
  formats = ",".join([f"'{x}'" for x in _formats.keys()])
670
441
  raise NotImplementedError(
671
- f"Files with extension {file_extension} are not supported. "
672
- f"Supported filetypes are {formats}."
442
+ f"Supported filetypes are {formats} - got {file_extension}."
673
443
  )
674
444
 
675
- if np.any(np.vectorize(len)(self.chain_identifier) > 2):
445
+ if np.any(np.vectorize(len)(self.chain_identifier) > 2) and func == _to_pdb:
676
446
  warnings.warn("Chain identifiers longer than one will be shortened.")
677
447
 
678
- if self.atom_coordinate.shape[0] > 10**5 and func == self._write_pdb:
679
- warnings.warn(
680
- "The structure contains more than 100,000 atoms. Consider using mmcif."
681
- )
682
-
683
448
  with open(filename, mode="w", encoding="utf-8") as ofile:
684
- ofile.write(func())
685
-
686
- def _write_pdb(self) -> str:
687
- """
688
- Returns a PDB string representation of the structure instance.
689
-
690
- Returns
691
- -------
692
- str
693
- String containing PDB file coordine lines.
694
- """
695
- data_out = []
696
- for index in range(self.atom_coordinate.shape[0]):
697
- x, y, z = self.atom_coordinate[index, :]
698
- line = list(" " * 80)
699
- line[0:6] = f"{self.record_type[index]:<6}"
700
- line[6:11] = f"{self.atom_serial_number[index]:>5}"
701
- line[12:16] = f"{self.atom_name[index]:<4}"
702
- line[16] = f"{self.alternate_location_indicator[index]:<1}"
703
- line[17:20] = f"{self.residue_name[index]:<3}"
704
- line[21] = f"{self.chain_identifier[index][0]:<1}"
705
- line[22:26] = f"{self.residue_sequence_number[index]:>4}"
706
- line[26] = f"{self.code_for_residue_insertion[index]:<1}"
707
- line[30:38] = f"{x:>8.3f}"
708
- line[38:46] = f"{y:>8.3f}"
709
- line[46:54] = f"{z:>8.3f}"
710
- line[54:60] = f"{self.occupancy[index]:>6.2f}"
711
- line[60:66] = f"{self.temperature_factor[index]:>6.2f}"
712
- line[72:76] = f"{self.segment_identifier[index]:>4}"
713
- line[76:78] = f"{self.element_symbol[index]:<2}"
714
- line[78:80] = f"{self.charge[index]:>2}"
715
- data_out.append("".join(line))
716
- data_out.append("END")
717
- data_out = "\n".join(data_out)
718
- return data_out
719
-
720
- def _write_mmcif(self) -> str:
721
- """
722
- Returns a MMCIF string representation of the structure instance.
723
-
724
- Returns
725
- -------
726
- str
727
- String containing MMCIF file coordinate lines.
728
- """
729
- model_num, entity_id = 1, 1
730
- data = {
731
- "group_PDB": [],
732
- "id": [],
733
- "type_symbol": [],
734
- "label_atom_id": [],
735
- "label_alt_id": [],
736
- "label_comp_id": [],
737
- "label_asym_id": [],
738
- "label_entity_id": [],
739
- "label_seq_id": [],
740
- "pdbx_PDB_ins_code": [],
741
- "Cartn_x": [],
742
- "Cartn_y": [],
743
- "Cartn_z": [],
744
- "occupancy": [],
745
- "B_iso_or_equiv": [],
746
- "pdbx_formal_charge": [],
747
- "auth_seq_id": [],
748
- "auth_comp_id": [],
749
- "auth_asym_id": [],
750
- "auth_atom_id": [],
751
- "pdbx_PDB_model_num": [],
752
- }
753
-
754
- for index in range(self.atom_coordinate.shape[0]):
755
- x, y, z = self.atom_coordinate[index, :]
756
- data["group_PDB"].append(self.record_type[index])
757
- data["id"].append(str(self.atom_serial_number[index]))
758
- data["type_symbol"].append(self.element_symbol[index])
759
- data["label_atom_id"].append(self.atom_name[index])
760
- data["label_alt_id"].append(self.alternate_location_indicator[index])
761
- data["label_comp_id"].append(self.residue_name[index])
762
- data["label_asym_id"].append(self.chain_identifier[index][0])
763
- data["label_entity_id"].append(str(entity_id))
764
- data["label_seq_id"].append(str(self.residue_sequence_number[index]))
765
- data["pdbx_PDB_ins_code"].append(self.code_for_residue_insertion[index])
766
- data["Cartn_x"].append(f"{x:.3f}")
767
- data["Cartn_y"].append(f"{y:.3f}")
768
- data["Cartn_z"].append(f"{z:.3f}")
769
- data["occupancy"].append(f"{self.occupancy[index]:.2f}")
770
- data["B_iso_or_equiv"].append(f"{self.temperature_factor[index]:.2f}")
771
- data["pdbx_formal_charge"].append(self.charge[index])
772
- data["auth_seq_id"].append(str(self.residue_sequence_number[index]))
773
- data["auth_comp_id"].append(self.residue_name[index])
774
- data["auth_asym_id"].append(self.chain_identifier[index][0])
775
- data["auth_atom_id"].append(self.atom_name[index])
776
- data["pdbx_PDB_model_num"].append(str(model_num))
777
-
778
- output_data = {"atom_site": data}
779
- original_file = self.metadata.get("filepath", "")
780
- try:
781
- new_data = {k: v for k, v in MMCIFParser(original_file).items()}
782
- index = self.atom_serial_number - 1
783
- new_data["atom_site"] = {
784
- k: [v[i] for i in index] for k, v in new_data["atom_site"].items()
785
- }
786
- new_data["atom_site"]["Cartn_x"] = data["Cartn_x"]
787
- new_data["atom_site"]["Cartn_y"] = data["Cartn_y"]
788
- new_data["atom_site"]["Cartn_z"] = data["Cartn_z"]
789
- output_data = new_data
790
- except Exception:
791
- pass
792
-
793
- ret = ""
794
- for category, subdict in output_data.items():
795
- if not len(subdict):
796
- continue
797
-
798
- ret += "#\n"
799
- is_loop = isinstance(subdict[list(subdict.keys())[0]], list)
800
- if not is_loop:
801
- for k in subdict:
802
- ret += f"_{category}.{k}\t{subdict[k]}\n"
803
- else:
804
- ret += "loop_\n"
805
- ret += "".join([f"_{category}.{k}\n" for k in subdict])
806
-
807
- subdict = {
808
- k: [_format_string(s) for s in v] for k, v in subdict.items()
809
- }
810
- key_length = {
811
- key: len(max(value, key=lambda x: len(x), default=""))
812
- for key, value in subdict.items()
813
- }
814
- padded_subdict = {
815
- key: [s.ljust(key_length[key] + 1) for s in values]
816
- for key, values in subdict.items()
817
- }
818
-
819
- data = [
820
- "".join([str(x) for x in content])
821
- for content in zip(*padded_subdict.values())
822
- ]
823
- ret += "\n".join([entry for entry in data]) + "\n"
824
-
825
- return ret
826
-
827
- def _write_gro(self) -> str:
828
- """
829
- Generate a GRO format string representation of the structure.
830
-
831
- Returns
832
- -------
833
- str
834
- String representation of the structure in GRO format.
835
- """
836
- ret = ""
837
- gro_files = np.unique(self.segment_identifier)
838
- for index, gro_file in enumerate(gro_files):
839
- subset = self[self.segment_identifier == gro_file]
840
-
841
- title = self.metadata.get("title", "Missing title")
842
- box_vectors = self.metadata.get("box_vectors")
843
- try:
844
- title = title[index]
845
- box_vectors = box_vectors[index]
846
- except Exception:
847
- pass
848
-
849
- if box_vectors is None:
850
- box_vectors = [0.0, 0.0, 0.0]
851
-
852
- num_atoms = subset.atom_coordinate.shape[0]
853
- lines = [title, f"{num_atoms}"]
854
- for i in range(num_atoms):
855
- res_num = subset.residue_sequence_number[i]
856
- res_name = subset.residue_name[i]
857
- atom_name = subset.atom_name[i]
858
- atom_num = subset.atom_serial_number[i]
859
-
860
- x, y, z = subset.atom_coordinate[i]
861
- coord = f"{atom_num % 100000:5d}{x:8.3f}{y:8.3f}{z:8.3f}"
862
- line = f"{res_num % 100000:5d}{res_name:5s}{atom_name:5s}{coord}"
863
-
864
- if "velocity" in subset.metadata:
865
- vx, vy, vz = subset.metadata["velocity"][i]
866
- line += f"{vx:8.4f}{vy:8.4f}{vz:8.4f}"
867
-
868
- lines.append(line)
869
-
870
- lines.append(" ".join(f"{v:.5f}" for v in box_vectors))
871
- ret += "\n".join(lines) + "\n"
872
- return ret
449
+ _ = ofile.write(func(self))
873
450
 
874
451
  def subset_by_chain(self, chain: str = None) -> "Structure":
875
452
  """
@@ -976,18 +553,7 @@ class Structure:
976
553
  >>> structure.center_of_mass()
977
554
  array([-0.89391639, 29.94908928, -2.64736741])
978
555
  """
979
- atoms = self.element_symbol
980
- match weight_type:
981
- case "atomic_weight":
982
- weights = [self._elements[atom].atomic_weight for atom in atoms]
983
- case "atomic_number":
984
- weights = [self._elements[atom].atomic_number for atom in atoms]
985
- case "equal":
986
- weights = np.ones((len(atoms)))
987
- case _:
988
- raise NotImplementedError(
989
- "weight_type can be 'atomic_weight', 'atomic_number' or 'equal."
990
- )
556
+ weights = self._get_atom_weights(self.element_symbol, weight_type)
991
557
  return np.dot(self.atom_coordinate.T, weights) / np.sum(weights)
992
558
 
993
559
  def rigid_transform(
@@ -995,6 +561,7 @@ class Structure:
995
561
  rotation_matrix: NDArray = None,
996
562
  translation: NDArray = None,
997
563
  use_geometric_center: bool = False,
564
+ center: NDArray = None,
998
565
  ) -> "Structure":
999
566
  """
1000
567
  Performs a rigid transform of internal structure coordinates.
@@ -1005,8 +572,14 @@ class Structure:
1005
572
  The rotation matrix to apply to the coordinates, defaults to identity.
1006
573
  translation : NDArray, optional
1007
574
  The vector to translate the coordinates by, defaults to 0.
575
+ center : NDArray, optional
576
+ Rotation center.
1008
577
  use_geometric_center : bool, optional
1009
- Whether to use geometric or coordinate center.
578
+ Whether to use geometric or mass center.
579
+
580
+ .. deprecated:: 0.3.2
581
+
582
+ All rotations are w.r.t to the center of mass.
1010
583
 
1011
584
  Returns
1012
585
  -------
@@ -1025,116 +598,30 @@ class Structure:
1025
598
  >>> translation = (0, 1, -5)
1026
599
  >>> )
1027
600
  """
1028
- out = np.empty_like(self.atom_coordinate.T)
601
+ ndim = self.atom_coordinate.shape[1]
1029
602
  if translation is None:
1030
- translation = np.zeros((self.atom_coordinate.shape[1]))
603
+ translation = np.zeros((ndim))
1031
604
  if rotation_matrix is None:
1032
- rotation_matrix = np.eye(self.atom_coordinate.shape[1])
605
+ rotation_matrix = np.eye(ndim)
606
+
607
+ # Assume we discretize the structure on a grid
608
+ min_coordinate = self.atom_coordinate.min(axis=0)
609
+ center = np.divide(self.atom_coordinate.max(axis=0) - min_coordinate, 2)
610
+ center = np.add(center, min_coordinate)
611
+
612
+ if not use_geometric_center:
613
+ center = self.center_of_mass()
1033
614
 
1034
- rigid_transform(
615
+ ret = self.copy()
616
+ _rigid_transform(
1035
617
  coordinates=self.atom_coordinate.T,
1036
618
  rotation_matrix=rotation_matrix,
1037
619
  translation=translation,
1038
- out=out,
1039
- use_geometric_center=use_geometric_center,
620
+ out=ret.atom_coordinate.T,
621
+ center=center,
1040
622
  )
1041
- ret = self.copy()
1042
- ret.atom_coordinate = out.T.copy()
1043
623
  return ret
1044
624
 
1045
- def centered(self) -> Tuple["Structure", NDArray]:
1046
- """
1047
- Shifts the structure analogous to :py:meth:`tme.density.Density.centered`.
1048
-
1049
- Returns
1050
- -------
1051
- Structure
1052
- A copy of the class instance whose data center of mass is in the
1053
- center of the data array.
1054
- NDArray
1055
- The coordinate translation.
1056
-
1057
- See Also
1058
- --------
1059
- :py:meth:`tme.density.Density.centered`
1060
-
1061
- Examples
1062
- --------
1063
- >>> from importlib_resources import files
1064
- >>> from tme import Structure
1065
- >>> fname = str(files("tests.data").joinpath("Structures/5khe.cif"))
1066
- >>> structure = Structure.from_file(filename=fname)
1067
- >>> centered_structure, translation = structure.centered()
1068
- >>> translation
1069
- array([34.89391639, 4.05091072, 36.64736741])
1070
- """
1071
- center_of_mass = self.center_of_mass()
1072
- enclosing_box = minimum_enclosing_box(coordinates=self.atom_coordinate.T)
1073
- shift = np.subtract(np.divide(enclosing_box, 2), center_of_mass)
1074
-
1075
- transformed_structure = self.rigid_transform(
1076
- translation=shift, rotation_matrix=np.eye(shift.size)
1077
- )
1078
-
1079
- return transformed_structure, shift
1080
-
1081
- def _coordinate_to_position(
1082
- self,
1083
- shape: Tuple[int],
1084
- sampling_rate: Tuple[float],
1085
- origin: Tuple[float],
1086
- ) -> (NDArray, Tuple[str], Tuple[int], float, Tuple[float]):
1087
- """
1088
- Converts coordinates to positions.
1089
-
1090
- Parameters
1091
- ----------
1092
- shape : Tuple[int,]
1093
- The desired shape of the output array.
1094
- sampling_rate : float
1095
- The sampling rate of the output array in unit of self.atom_coordinate.
1096
- origin : Tuple[float,]
1097
- The origin of the coordinate system.
1098
-
1099
- Returns
1100
- -------
1101
- Tuple[NDArray, List[str], Tuple[int, ], float, Tuple[float,]]
1102
- Returns positions, atom_types, shape, sampling_rate, and origin.
1103
- """
1104
- coordinates = self.atom_coordinate.copy()
1105
- atom_types = self.element_symbol.copy()
1106
-
1107
- coordinates = coordinates
1108
- sampling_rate = 1 if sampling_rate is None else sampling_rate
1109
- adjust_origin = origin is not None and shape is None
1110
- origin = coordinates.min(axis=0) if origin is None else origin
1111
- positions = (coordinates - origin) / sampling_rate
1112
- positions = np.rint(positions).astype(int)
1113
-
1114
- if adjust_origin:
1115
- left_shift = positions.min(axis=0)
1116
- positions -= left_shift
1117
- shape = positions.max(axis=0) + 1
1118
- origin = origin + np.multiply(left_shift, sampling_rate)
1119
-
1120
- if shape is None:
1121
- shape = positions.max(axis=0) + 1
1122
-
1123
- valid_positions = np.sum(
1124
- np.logical_and(positions < shape, positions >= 0), axis=1
1125
- )
1126
-
1127
- positions = positions[valid_positions == positions.shape[1], :]
1128
- atom_types = atom_types[valid_positions == positions.shape[1]]
1129
-
1130
- self.metadata["nAtoms_outOfBound"] = 0
1131
- if positions.shape[0] != coordinates.shape[0]:
1132
- out_of_bounds = coordinates.shape[0] - positions.shape[0]
1133
- print(f"{out_of_bounds}/{coordinates.shape[0]} atoms were out of bounds.")
1134
- self.metadata["nAtoms_outOfBound"] = out_of_bounds
1135
-
1136
- return positions, atom_types, shape, sampling_rate, origin
1137
-
1138
625
  def _position_to_vdw_sphere(
1139
626
  self,
1140
627
  positions: Tuple[float],
@@ -1339,10 +826,9 @@ class Structure:
1339
826
  atoms : Tuple of strings, optional
1340
827
  The atoms to get the weights for. If None, weights for all atoms
1341
828
  are used. Default is None.
1342
-
1343
829
  weight_type : str, optional
1344
830
  The type of weights to return. This can either be 'atomic_weight',
1345
- 'atomic_number', or 'van_der_waals_radius'. Default is 'atomic_weight'.
831
+ 'atomic_number', or 'equal'. Default is 'atomic_weight'.
1346
832
 
1347
833
  Returns
1348
834
  -------
@@ -1355,9 +841,11 @@ class Structure:
1355
841
  weight = [self._elements[atom].atomic_weight for atom in atoms]
1356
842
  case "atomic_number":
1357
843
  weight = [self._elements[atom].atomic_number for atom in atoms]
844
+ case "equal":
845
+ weight = np.ones((len(atoms)))
1358
846
  case _:
1359
847
  raise NotImplementedError(
1360
- "weight_type can either be 'atomic_weight' or 'atomic_number'"
848
+ "weight_type can be 'atomic_weight', 'atomic_number' or 'equal'."
1361
849
  )
1362
850
  return weight
1363
851
 
@@ -1457,15 +945,60 @@ class Structure:
1457
945
  )
1458
946
 
1459
947
  temp = self.subset_by_chain(chain=chain)
1460
- positions, atoms, _shape, sampling_rate, origin = temp._coordinate_to_position(
1461
- shape=shape, sampling_rate=sampling_rate, origin=origin
948
+ positions, valid, _shape, origin = _coordinate_to_position(
949
+ coordinates=temp.atom_coordinate,
950
+ shape=shape,
951
+ sampling_rate=sampling_rate,
952
+ origin=origin,
1462
953
  )
954
+ positions = positions[valid]
955
+ atoms = temp.element_symbol[valid]
1463
956
  volume = np.zeros(_shape, dtype=np.float32)
1464
957
  if weight_type in ("atomic_weight", "atomic_number"):
1465
- weights = temp._get_atom_weights(atoms=atoms, weight_type=weight_type)
1466
- np.add.at(volume, tuple(positions.T), weights)
958
+ weights = np.array(
959
+ temp._get_atom_weights(atoms=atoms, weight_type=weight_type)
960
+ )
961
+ p0 = np.floor(positions).astype(int)
962
+
963
+ x0, y0, z0 = p0.T
964
+ x1, y1, z1 = (p0 + 1).T
965
+ dx, dy, dz = (positions - p0).T
966
+
967
+ w000 = (1 - dx) * (1 - dy) * (1 - dz)
968
+ w001 = (1 - dx) * (1 - dy) * dz
969
+ w010 = (1 - dx) * dy * (1 - dz)
970
+ w011 = (1 - dx) * dy * dz
971
+ w100 = dx * (1 - dy) * (1 - dz)
972
+ w101 = dx * (1 - dy) * dz
973
+ w110 = dx * dy * (1 - dz)
974
+ w111 = dx * dy * dz
975
+
976
+ corners = [
977
+ ((x0, y0, z0), w000),
978
+ ((x0, y0, z1), w001),
979
+ ((x0, y1, z0), w010),
980
+ ((x0, y1, z1), w011),
981
+ ((x1, y0, z0), w100),
982
+ ((x1, y0, z1), w101),
983
+ ((x1, y1, z0), w110),
984
+ ((x1, y1, z1), w111),
985
+ ]
986
+
987
+ for positions, tril_weights in corners:
988
+ positions = np.array(positions).T
989
+ keep = np.all(
990
+ np.logical_and(positions < _shape, positions >= 0), axis=1
991
+ )
992
+
993
+ # Safeguard, but this should not happen using _coordinate_to_position
994
+ positions = positions[keep]
995
+ _weights = (tril_weights * weights)[keep]
996
+ np.add.at(volume, tuple(positions.T), _weights)
997
+
1467
998
  elif weight_type == "van_der_waals_radius":
1468
- self._position_to_vdw_sphere(positions, atoms, sampling_rate, volume)
999
+ self._position_to_vdw_sphere(
1000
+ np.rint(positions).astype(int), atoms, sampling_rate, volume
1001
+ )
1469
1002
  elif weight_type == "scattering_factors":
1470
1003
  self._position_to_scattering_factors(
1471
1004
  positions,
@@ -1493,8 +1026,6 @@ class Structure:
1493
1026
  sampling_rate=sampling_rate,
1494
1027
  **weight_type_args,
1495
1028
  )
1496
-
1497
- self.metadata.update(temp.metadata)
1498
1029
  return volume, origin, sampling_rate
1499
1030
 
1500
1031
  @classmethod
@@ -1502,9 +1033,9 @@ class Structure:
1502
1033
  cls,
1503
1034
  structure1: "Structure",
1504
1035
  structure2: "Structure",
1505
- origin: NDArray = None,
1506
1036
  sampling_rate: float = None,
1507
1037
  weighted: bool = False,
1038
+ **kwargs,
1508
1039
  ) -> float:
1509
1040
  """
1510
1041
  Compute root mean square deviation (RMSD) between two structures with the
@@ -1514,8 +1045,6 @@ class Structure:
1514
1045
  ----------
1515
1046
  structure1, structure2 : :py:class:`Structure`
1516
1047
  Structure instances to compare.
1517
- origin : tuple of floats, optional
1518
- Coordinate system origin. For computing RMSD on discretized grids.
1519
1048
  sampling_rate : tuple of floats, optional
1520
1049
  Sampling rate in units of :py:attr:`atom_coordinate`.
1521
1050
  For computing RMSD on discretized grids.
@@ -1543,38 +1072,29 @@ class Structure:
1543
1072
  >>> Structure.compare_structures(structure, structure)
1544
1073
  0.0
1545
1074
  """
1546
- if origin is None:
1547
- origin = np.zeros(structure1.atom_coordinate.shape[1])
1548
-
1549
1075
  coordinates1 = structure1.atom_coordinate
1550
1076
  coordinates2 = structure2.atom_coordinate
1551
1077
  atoms1, atoms2 = structure1.element_symbol, structure2.element_symbol
1552
- if sampling_rate is not None:
1553
- coordinates1 = np.rint(
1554
- np.divide(np.subtract(coordinates1, origin), sampling_rate)
1555
- ).astype(int)
1556
- coordinates2 = np.rint(
1557
- np.divide(np.subtract(coordinates2, origin), sampling_rate)
1558
- ).astype(int)
1559
-
1560
- weights1 = np.ones_like(structure1.atom_coordinate.shape[0])
1561
- weights2 = np.ones_like(structure2.atom_coordinate.shape[0])
1562
- if weighted:
1563
- weights1 = np.array(structure1._get_atom_weights(atoms=atoms1))
1564
- weights2 = np.array(structure2._get_atom_weights(atoms=atoms2))
1565
-
1566
1078
  if not np.allclose(coordinates1.shape, coordinates2.shape):
1567
1079
  raise ValueError(
1568
1080
  "Input structures need to have the same number of coordinates."
1569
1081
  )
1570
- if not np.allclose(weights1.shape, weights2.shape):
1082
+ if not np.allclose(atoms1.shape, atoms2.shape):
1571
1083
  raise ValueError("Input structures need to have the same number of atoms.")
1572
1084
 
1085
+ if sampling_rate is not None:
1086
+ coordinates1 = np.divide(coordinates1, sampling_rate).astype(int)
1087
+ coordinates2 = np.divide(coordinates2, sampling_rate).astype(int)
1088
+
1089
+ weights1 = np.ones(coordinates1.shape[0])
1090
+ weights2 = np.ones(coordinates2.shape[0])
1091
+ if weighted:
1092
+ weights1 = np.array(structure1._get_atom_weights(atoms=atoms1))
1093
+ weights2 = np.array(structure2._get_atom_weights(atoms=atoms2))
1094
+
1573
1095
  squared_diff = np.sum(np.square(coordinates1 - coordinates2), axis=1)
1574
1096
  weighted_quared_diff = squared_diff * ((weights1 + weights2) / 2)
1575
- rmsd = np.sqrt(np.mean(weighted_quared_diff))
1576
-
1577
- return rmsd
1097
+ return np.sqrt(np.mean(weighted_quared_diff))
1578
1098
 
1579
1099
  @classmethod
1580
1100
  def align_structures(
@@ -1595,9 +1115,6 @@ class Structure:
1595
1115
  Structure instances to align.
1596
1116
  origin : tuple of floats, optional
1597
1117
  Coordinate system origin. For computing RMSD on discretized grids.
1598
- sampling_rate : tuple of floats, optional
1599
- Sampling rate in units of :py:attr:`atom_coordinate`.
1600
- For computing RMSD on discretized grids.
1601
1118
  weighted : bool, optional
1602
1119
  Whether atoms should be weighted by their atomic weight.
1603
1120
 
@@ -1622,29 +1139,10 @@ class Structure:
1622
1139
  >>> aligned, rmsd = Structure.align_structures(structure, transformed)
1623
1140
  Initial RMSD: 31.07189 - Final RMSD: 0.00000
1624
1141
  """
1625
- if origin is None:
1626
- origin = np.minimum(
1627
- structure1.atom_coordinate.min(axis=0),
1628
- structure2.atom_coordinate.min(axis=0),
1629
- ).astype(int)
1630
-
1631
- initial_rmsd = cls.compare_structures(
1632
- structure1=structure1,
1633
- structure2=structure2,
1634
- origin=origin,
1635
- sampling_rate=sampling_rate,
1636
- weighted=weighted,
1637
- )
1142
+ rmsd = cls.compare_structures(structure1, structure2, weighted=weighted)
1638
1143
 
1639
1144
  reference = structure1.atom_coordinate.copy()
1640
1145
  query = structure2.atom_coordinate.copy()
1641
- if sampling_rate is not None:
1642
- reference, atoms1, shape, _, _ = structure1._coordinate_to_position(
1643
- shape=None, sampling_rate=sampling_rate, origin=origin
1644
- )
1645
- query, atoms2, shape, _, _ = structure2._coordinate_to_position(
1646
- shape=None, sampling_rate=sampling_rate, origin=origin
1647
- )
1648
1146
 
1649
1147
  reference_mean = reference.mean(axis=0)
1650
1148
  query_mean = query.mean(axis=0)
@@ -1667,16 +1165,8 @@ class Structure:
1667
1165
  ret = structure2.copy()
1668
1166
  ret.atom_coordinate = np.dot(query + query_mean, rotation) + translation
1669
1167
 
1670
- final_rmsd = cls.compare_structures(
1671
- structure1=temp,
1672
- structure2=ret,
1673
- origin=origin,
1674
- sampling_rate=None,
1675
- weighted=weighted,
1676
- )
1677
-
1678
- print(f"Initial RMSD: {initial_rmsd:.5f} - Final RMSD: {final_rmsd:.5f}")
1679
-
1168
+ final_rmsd = cls.compare_structures(temp, ret, weighted=weighted)
1169
+ print(f"Initial RMSD: {rmsd:.5f} - Final RMSD: {final_rmsd:.5f}")
1680
1170
  return ret, final_rmsd
1681
1171
 
1682
1172
  def align_to_axis(
@@ -1684,10 +1174,490 @@ class Structure:
1684
1174
  ):
1685
1175
  if coordinates is None:
1686
1176
  coordinates = self.atom_coordinate
1687
-
1688
1177
  return align_to_axis(coordinates, axis=axis, flip=flip, **kwargs)
1689
1178
 
1690
1179
 
1180
def _coordinate_to_position(
    coordinates: NDArray,
    shape: Tuple[int],
    sampling_rate: Tuple[float, ...],
    origin: Tuple[float, ...],
) -> Tuple[NDArray, NDArray, Tuple[int, ...], Tuple[float, ...]]:
    """
    Converts coordinates to (fractional) positions on a grid.

    Parameters
    ----------
    coordinates : NDArray
        Coordinates to convert, one row per point.
    shape : Tuple[int,]
        The desired shape of the grid. If None, a shape that contains all
        positions (plus one voxel of padding) is computed.
    sampling_rate : float
        The sampling rate of the grid in units of ``coordinates``.
        None is treated as 1.
    origin : Tuple[float,]
        The origin of the coordinate system. If None, the per-axis minimum
        of ``coordinates`` is used.

    Returns
    -------
    Tuple[NDArray, NDArray, Tuple[int, ...], Tuple[float, ...]]
        Positions (not rounded, one row per input coordinate), indices of the
        rows that fall inside the grid, grid shape, and origin.

    Warns
    -----
    UserWarning
        If at least one position lies outside of the grid.
    """
    sampling_rate = 1 if sampling_rate is None else sampling_rate

    # A user-supplied origin is only moved when the grid shape is free to
    # follow the data, i.e., when no shape was requested.
    adjust_origin = origin is not None and shape is None
    origin = coordinates.min(axis=0) if origin is None else origin
    positions = (coordinates - origin) / sampling_rate

    # 0.3.2 switched from rint to ceil to accommodate the interpolation scheme
    if adjust_origin:
        left_shift = positions.min(axis=0)
        positions -= left_shift
        origin = origin + np.multiply(left_shift, sampling_rate)

    if shape is None:
        # One voxel of padding past the largest (ceiled) position
        shape = np.ceil(positions.max(axis=0)) + 1

    # A position is valid if it lies within [0, shape) along every axis
    valid_positions = np.all(
        np.logical_and(positions < shape, positions >= 0), axis=1
    )

    n_mapped = valid_positions.sum()
    if n_mapped != coordinates.shape[0]:
        out_of_bounds = coordinates.shape[0] - n_mapped
        warnings.warn(
            f"{out_of_bounds}/{coordinates.shape[0]} atoms were out of bounds."
        )

    shape = tuple(int(x) for x in shape)
    origin = tuple(float(x) for x in origin)
    return positions, np.where(valid_positions)[0], shape, origin
1234
+
1235
+
1236
def _convert_dtypes(data: Dict[str, List], mapping: Dict):
    """
    Convert key values in data according to mapping.

    Parameters
    ----------
    data : Dict
        Mapping of keys to list of values
    mapping : Dict
        Mapping of key in return dict to (key, dtype) in data. Besides regular
        dtypes, the string "base-36" is accepted: digit-only entries are kept
        as is, all other entries are decoded from base 36 and shifted so that
        "A0000" corresponds to 100000. The resulting array is of integer type.

    Returns
    -------
    dict
        Key-value map using key-dtype pairs in mapping on data. Keys missing
        from data, and keys whose conversion raises a ValueError, are filled
        with "." for str and 0 for all other dtypes.

    Warns
    -----
    UserWarning
        If keys requested by mapping are missing from data.
    """
    out = {}
    max_len = max([len(t) for t in data.values() if hasattr(t, "__len__")])

    missing_keys = set()
    for out_key, (inner_key, dtype) in mapping.items():
        if inner_key in data:
            continue
        missing_keys.add(inner_key)
        default = "." if dtype is str else 0
        # "base-36" is a parsing directive rather than a numpy dtype, the
        # decoded values are integers.
        fill_dtype = int if dtype == "base-36" else dtype
        out[out_key] = np.repeat(default, max_len).astype(fill_dtype)

    if len(missing_keys):
        msg = ", ".join([f"'{x}'" for x in missing_keys])
        warnings.warn(f"Missing keys: ({msg}) in data - filling with default value.")

    for out_key, (inner_key, dtype) in mapping.items():
        default = "." if dtype is str else 0

        # Missing keys were already filled with their default value above
        if inner_key in missing_keys:
            continue

        # Avoid modifying input dictionaries
        out_data = data[inner_key]
        if isinstance(data[inner_key][0], str):
            out_data = [str(x).strip() for x in data[inner_key]]

        out_data = np.asarray(out_data)
        if dtype is int:
            # "." marks an unset value
            out_data = np.where(out_data == ".", "0", out_data)
        elif dtype == "base-36":
            dtype = int
            base36_offset = int("A0000", 36) - 100000
            out_data = np.where(
                np.char.isdigit(out_data),
                out_data,
                np.vectorize(lambda x: int(x, 36) - base36_offset)(out_data),
            )
        try:
            out[out_key] = np.asarray(out_data, dtype=dtype)
        except ValueError:
            print(
                f"Converting {out_key} to {dtype} failed. Setting {out_key} to {default}."
            )
            out[out_key] = np.repeat(default, max_len).astype(dtype)
    return out
1297
+
1298
+
1299
def _parse_mmcif(filename: str) -> Dict:
    """
    Read a macromolecular Crystallographic Information File (mmCIF) into a
    dictionary of arrays.

    Parameters
    ----------
    filename : str
        Path to the mmCIF file.

    Returns
    -------
    dict
        The ``atom_site`` columns converted by ``_convert_dtypes`` under their
        PDB-style names, the coordinates as float32 array of shape (n, 3)
        under ``atom_coordinate`` and experimental details under ``metadata``.
        Only details whose category is present in the file are included.
    """
    parsed = MMCIFParser(filename)
    atom_site = parsed["atom_site"]

    # Output name -> (atom_site column, dtype). The chain identifier is taken
    # from label_asym_id (not auth_asym_id).
    column_map = {
        "record_type": ("group_PDB", str),
        "atom_serial_number": ("id", int),
        "atom_name": ("label_atom_id", str),
        "alternate_location_indicator": ("label_alt_id", str),
        "residue_name": ("label_comp_id", str),
        "chain_identifier": ("label_asym_id", str),
        "residue_sequence_number": ("label_seq_id", int),
        "code_for_residue_insertion": ("pdbx_PDB_ins_code", str),
        "occupancy": ("occupancy", float),
        "temperature_factor": ("B_iso_or_equiv", float),
        "segment_identifier": ("label_entity_id", str),
        "element_symbol": ("type_symbol", str),
        "charge": ("pdbx_formal_charge", str),
    }
    out = _convert_dtypes(atom_site, column_map)

    # Columns holding a single value are repeated to the longest column length
    n_entries = max(len(column) for column in out.values())
    for name in list(out):
        column = out[name]
        if column.size == 1:
            out[name] = np.repeat(column, n_entries // column.size)

    out["metadata"] = {}
    cartesian = np.array(
        [atom_site["Cartn_x"], atom_site["Cartn_y"], atom_site["Cartn_z"]],
        dtype=np.float32,
    )
    out["atom_coordinate"] = cartesian.T

    # Output name -> (mmCIF category, item, fallback if the item is absent)
    detail_map = {
        "resolution": ("em_3d_reconstruction", "resolution", np.nan),
        "resolution_method": ("em_3d_reconstruction", "resolution_method", np.nan),
        "method": ("exptl", "method", np.nan),
        "electron_source": ("em_imaging", "electron_source", np.nan),
        "illumination_mode": ("em_imaging", "illumination_mode", np.nan),
        "microscope_model": ("em_imaging", "microscope_model", np.nan),
    }
    for name, (category, item, fallback) in detail_map.items():
        if category in parsed:
            out["metadata"][name] = parsed[category].get(item, fallback)

    return out
1369
+
1370
+
1371
def _parse_pdb(filename: str) -> Dict:
    """
    Read a Protein Data Bank (PDB) file into a dictionary of arrays.

    Parameters
    ----------
    filename : str
        Path to the PDB file.

    Returns
    -------
    dict
        The coordinate section fields converted by ``_convert_dtypes``, the
        coordinates as float32 array under ``atom_coordinate`` and the
        parser's ``details`` entry under ``metadata``.
    """
    parsed = PDBParser(filename)

    # Parser keys already match the output names, so every field maps onto
    # itself. Atom serial numbers may be base-36 encoded.
    field_dtypes = (
        ("record_type", str),
        ("atom_serial_number", "base-36"),
        ("atom_name", str),
        ("alternate_location_indicator", str),
        ("residue_name", str),
        ("chain_identifier", str),
        ("residue_sequence_number", int),
        ("code_for_residue_insertion", str),
        ("occupancy", float),
        ("temperature_factor", float),
        ("segment_identifier", str),
        ("element_symbol", str),
        ("charge", str),
    )
    column_map = {field: (field, dtype) for field, dtype in field_dtypes}

    converted = _convert_dtypes(parsed, column_map)
    converted["metadata"] = parsed["details"]
    converted["atom_coordinate"] = np.array(
        parsed["atom_coordinate"], dtype=np.float32
    )
    return converted
1413
+
1414
+
1415
def _parse_gro(filename):
    """
    Read a GRO file into a dictionary of arrays.

    Parameters
    ----------
    filename : str
        Path to the GRO file.

    Returns
    -------
    dict
        The atom fields converted by ``_convert_dtypes``, the coordinates as
        float32 array under ``atom_coordinate`` and the parser's ``title``,
        ``time`` and ``box_vectors`` entries under ``metadata``.

    Warns
    -----
    UserWarning
        If more than one distinct segment identifier is found.
    """
    parsed = GROParser(filename)

    # Output name -> (parser key, dtype)
    column_map = {
        "record_type": ("record_type", str),
        "atom_serial_number": ("atom_number", int),
        "atom_name": ("atom_name", str),
        "alternate_location_indicator": ("label_alt_id", str),
        "residue_name": ("residue_name", str),
        "chain_identifier": ("segment_identifier", str),
        "residue_sequence_number": ("residue_number", int),
        "code_for_residue_insertion": ("pdbx_PDB_ins_code", str),
        "occupancy": ("occupancy", float),
        "temperature_factor": ("B_iso_or_equiv", float),
        "segment_identifier": ("segment_identifier", str),
        "element_symbol": ("type_symbol", str),
        "charge": ("pdbx_formal_charge", str),
    }
    converted = _convert_dtypes(parsed, column_map)

    n_segments = len(np.unique(converted["segment_identifier"]))
    if n_segments > 1:
        warnings.warn(
            "Multiple GRO files detected - treating them as a single Structure. "
            "GRO file number is given by segment_identifier according to the "
            "input file. Note: You need to subset the Structure to operate on "
            "individual GRO files."
        )

    metadata = {}
    for key in ("title", "time", "box_vectors"):
        metadata[key] = parsed.get(key)
    converted["metadata"] = metadata
    converted["atom_coordinate"] = np.asarray(
        parsed["atom_coordinate"], dtype=np.float32
    )
    return converted
1449
+
1450
+
1451
def _to_pdb(structure: Structure) -> str:
    """
    Serialize a structure to a PDB formatted string.

    Parameters
    ----------
    structure : :py:class:`Structure`
        Structure instance to serialize.

    Returns
    -------
    str
        One fixed-column record per atom followed by a terminating END line.
    """
    # int("A0000", 36) - 100000
    base36_shift = 16696160

    def _serial(number: int) -> str:
        """Atom numbers of 100000 and above are written in base 36."""
        if number < 10**5:
            return str(number)
        return f"{np.base_repr(base36_shift + number, 36)}"

    records = []
    n_atoms = structure.atom_coordinate.shape[0]
    for idx in range(n_atoms):
        x, y, z = structure.atom_coordinate[idx, :]

        # (target columns, text), applied in order onto a blank 80 column
        # record. Integer targets address a single column.
        fields = (
            (slice(0, 6), f"{structure.record_type[idx]:<6}"),
            (slice(6, 11), f"{_serial(structure.atom_serial_number[idx]):>5}"),
            (slice(12, 16), f"{structure.atom_name[idx]:<4}"),
            (16, f"{structure.alternate_location_indicator[idx]:<1}"),
            (slice(17, 20), f"{structure.residue_name[idx]:<3}"),
            (21, f"{structure.chain_identifier[idx][0]:<1}"),
            (slice(22, 26), f"{structure.residue_sequence_number[idx]:>4}"),
            (26, f"{structure.code_for_residue_insertion[idx]:<1}"),
            (slice(30, 38), f"{x:>8.3f}"),
            (slice(38, 46), f"{y:>8.3f}"),
            (slice(46, 54), f"{z:>8.3f}"),
            (slice(54, 60), f"{structure.occupancy[idx]:>6.2f}"),
            (slice(60, 66), f"{structure.temperature_factor[idx]:>6.2f}"),
            (slice(72, 76), f"{structure.segment_identifier[idx]:>4}"),
            (slice(76, 78), f"{structure.element_symbol[idx]:<2}"),
            (slice(78, 80), f"{structure.charge[idx]:>2}"),
        )

        record = list(" " * 80)
        for columns, text in fields:
            record[columns] = text
        records.append("".join(record))

    records.append("END")
    return "\n".join(records)
1496
+
1497
+
1498
def _to_mmcif(structure: Structure) -> str:
    """
    Serialize a structure to a MMCIF formatted string.

    If the file given by the ``filepath`` entry of the structure's metadata
    can be parsed as mmCIF, all of its categories are written, with the
    ``atom_site`` rows selected by ``atom_serial_number - 1`` and the
    coordinate columns replaced by the structure's coordinates. If this fails
    for any reason, only a minimal ``atom_site`` table is written.

    Parameters
    ----------
    structure : :py:class:`Structure`
        Structure instance to serialize.

    Returns
    -------
    str
        MMCIF string representation of structure.
    """
    model_num, entity_id = 1, 1
    atoms = range(structure.atom_coordinate.shape[0])

    cartn_x, cartn_y, cartn_z = [], [], []
    for i in atoms:
        x, y, z = structure.atom_coordinate[i, :]
        cartn_x.append(f"{x:.3f}")
        cartn_y.append(f"{y:.3f}")
        cartn_z.append(f"{z:.3f}")

    # Column order determines the order of items in the written loop
    atom_site = {
        "group_PDB": [structure.record_type[i] for i in atoms],
        "id": [str(structure.atom_serial_number[i]) for i in atoms],
        "type_symbol": [structure.element_symbol[i] for i in atoms],
        "label_atom_id": [structure.atom_name[i] for i in atoms],
        "label_alt_id": [structure.alternate_location_indicator[i] for i in atoms],
        "label_comp_id": [structure.residue_name[i] for i in atoms],
        "label_asym_id": [structure.chain_identifier[i] for i in atoms],
        "label_entity_id": [str(entity_id) for _ in atoms],
        "label_seq_id": [str(structure.residue_sequence_number[i]) for i in atoms],
        "pdbx_PDB_ins_code": [
            structure.code_for_residue_insertion[i] for i in atoms
        ],
        "Cartn_x": cartn_x,
        "Cartn_y": cartn_y,
        "Cartn_z": cartn_z,
        "occupancy": [f"{structure.occupancy[i]:.2f}" for i in atoms],
        "B_iso_or_equiv": [f"{structure.temperature_factor[i]:.2f}" for i in atoms],
        "pdbx_formal_charge": [structure.charge[i] for i in atoms],
        "auth_seq_id": [str(structure.residue_sequence_number[i]) for i in atoms],
        "auth_comp_id": [structure.residue_name[i] for i in atoms],
        "auth_asym_id": [structure.chain_identifier[i] for i in atoms],
        "auth_atom_id": [structure.atom_name[i] for i in atoms],
        "pdbx_PDB_model_num": [str(model_num) for _ in atoms],
    }

    output_data = {"atom_site": atom_site}
    source_file = structure.metadata.get("filepath", "")
    try:
        # Best effort: reuse the source file's categories when it is available
        source = dict(MMCIFParser(source_file).items())
        # NOTE(review): assumes atom serial numbers are one-based row indices
        # into the source atom_site table - confirm
        selection = structure.atom_serial_number - 1
        source["atom_site"] = {
            item: [values[i] for i in selection]
            for item, values in source["atom_site"].items()
        }
        source["atom_site"]["Cartn_x"] = atom_site["Cartn_x"]
        source["atom_site"]["Cartn_y"] = atom_site["Cartn_y"]
        source["atom_site"]["Cartn_z"] = atom_site["Cartn_z"]
        output_data = source
    except Exception:
        pass

    chunks = []
    for category, subdict in output_data.items():
        if len(subdict) == 0:
            continue

        chunks.append("#\n")
        # A category is written as a loop if its first item holds a list
        first_value = subdict[next(iter(subdict.keys()))]
        if not isinstance(first_value, list):
            chunks.extend(f"_{category}.{k}\t{subdict[k]}\n" for k in subdict)
            continue

        chunks.append("loop_\n")
        chunks.extend(f"_{category}.{k}\n" for k in subdict)

        formatted = {
            item: [_format_string(entry) for entry in values]
            for item, values in subdict.items()
        }
        # Pad each column to its longest entry plus one separating space
        padded = {}
        for item, values in formatted.items():
            width = max((len(entry) for entry in values), default=0)
            padded[item] = [entry.ljust(width + 1) for entry in values]

        rows = ["".join([str(cell) for cell in row]) for row in zip(*padded.values())]
        chunks.append("\n".join(rows) + "\n")
    return "".join(chunks)
1606
+
1607
+
1608
def _to_gro(structure: Structure) -> str:
    """
    Generate a GRO format string representation of the structure.

    One GRO block (title, atom count, atom lines, box vectors) is written per
    unique segment identifier. The ``title`` and ``box_vectors`` metadata
    entries may either hold one value per block or a single value shared by
    all blocks. Velocities are appended to the atom lines if the metadata
    contains a ``velocity`` entry.

    Parameters
    ----------
    structure : :py:class:`Structure`
        Structure instance to serialize.

    Returns
    -------
    str
        GRO string representation of structure.
    """
    default_title = "Missing title"
    title_meta = structure.metadata.get("title", default_title)
    box_meta = structure.metadata.get("box_vectors")

    blocks = []
    gro_files = np.unique(structure.segment_identifier)
    for index, gro_file in enumerate(gro_files):
        subset = structure[structure.segment_identifier == gro_file]

        # Only index into collections of titles. Indexing a plain string
        # would reduce the title to a single character.
        title = title_meta
        if title is not None and not isinstance(title, (str, bytes)):
            try:
                title = title_meta[index]
            except (IndexError, KeyError, TypeError):
                pass
        if title is None:
            title = default_title

        # Box vectors are selected independently of the title. A flat vector
        # is shared by all blocks and must not be reduced to a scalar.
        box_vectors = box_meta
        if box_meta is not None:
            try:
                candidate = box_meta[index]
            except (IndexError, KeyError, TypeError):
                candidate = None
            if candidate is not None and np.ndim(candidate) > 0:
                box_vectors = candidate
        if box_vectors is None:
            box_vectors = [0.0, 0.0, 0.0]

        num_atoms = subset.atom_coordinate.shape[0]
        has_velocity = "velocity" in subset.metadata
        lines = [title, f"{num_atoms}"]
        for i in range(num_atoms):
            res_num = subset.residue_sequence_number[i]
            res_name = subset.residue_name[i]
            atom_name = subset.atom_name[i]
            atom_num = subset.atom_serial_number[i]

            # Residue and atom numbers wrap around to fit their five columns
            x, y, z = subset.atom_coordinate[i]
            coord = f"{atom_num % 100000:5d}{x:8.3f}{y:8.3f}{z:8.3f}"
            line = f"{res_num % 100000:5d}{res_name:5s}{atom_name:5s}{coord}"

            if has_velocity:
                vx, vy, vz = subset.metadata["velocity"][i]
                line += f"{vx:8.4f}{vy:8.4f}{vz:8.4f}"

            lines.append(line)

        lines.append(" ".join(f"{v:.5f}" for v in box_vectors))
        blocks.append("\n".join(lines) + "\n")
    return "".join(blocks)
1659
+
1660
+
1691
1661
  @dataclass(frozen=True, repr=True)
1692
1662
  class _Elements:
1693
1663
  """Lookup table for chemical elements."""