molscene 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
molscene/Scene.py ADDED
@@ -0,0 +1,1296 @@
1
+ """
2
+ Python library to allow easy handling of coordinate files for molecular dynamics using pandas DataFrames.
3
+ """
4
+
5
+
6
+ import pandas
7
+ import numpy as np
8
+ import io
9
+ from typing import Union, Tuple, Sequence, List
10
+ import re
11
+ from scipy.spatial import cKDTree, distance
12
+ import logging
13
+ from . import utils
14
+
15
+
16
+
17
+
18
__author__ = 'Carlos Bueno'

# Three-letter amino-acid residue names -> one-letter codes.
_protein_residues = dict(
    ALA='A', CYS='C', ASP='D', GLU='E',
    PHE='F', GLY='G', HIS='H', ILE='I',
    LYS='K', LEU='L', MET='M', ASN='N',
    PRO='P', GLN='Q', ARG='R', SER='S',
    THR='T', VAL='V', TRP='W', TYR='Y',
)

# DNA residue names ('D' + base) -> base letter.
_DNA_residues = {f'D{base}': base for base in 'ACGT'}

# RNA residue names are the base letter itself.
_RNA_residues = {base: base for base in 'ACGU'}
29
+
30
+ class _FrameAccessor:
31
+ def __init__(self, scene: "Scene"):
32
+ self._scene = scene
33
+
34
+ def __getitem__(self, index):
35
+ # Retrieve the multi-frame array from the parent Scene.
36
+ frames = self._scene.get_coordinate_frames()
37
+ # Support integer indexing (or slicing that returns one or more frames)
38
+ new_coords = frames[index]
39
+ # If a single frame is selected, new_coords has shape (n_atoms, 3).
40
+ # In that case, create a new Scene that has the same metadata but with
41
+ # the coordinates replaced by this frame. Importantly, we do NOT copy
42
+ # the entire multi-frame data.
43
+ if new_coords.ndim == 2:
44
+ new_scene = self._scene.copy(deep=True)
45
+ # Remove the heavy multi-frame data from the new scene.
46
+ new_scene._meta.pop('coordinate_frames', None)
47
+ new_scene.set_coordinates(new_coords)
48
+ return new_scene
49
+ # If the index returns multiple frames (e.g. a slice), return a list
50
+ # of Scene objects, one per frame.
51
+ elif new_coords.ndim == 3:
52
+ scenes = []
53
+ for coords in new_coords:
54
+ new_scene = self._scene.copy(deep=True)
55
+ new_scene._meta.pop('coordinate_frames', None)
56
+ new_scene.set_coordinates(coords)
57
+ scenes.append(new_scene)
58
+ return scenes
59
+ else:
60
+ raise ValueError("Invalid frame dimensions")
61
+
62
+ def __iter__(self):
63
+ frames = self._scene.get_coordinate_frames()
64
+ for i in range(frames.shape[0]):
65
+ yield self[i]
66
+
67
+
68
+ class Scene(pandas.DataFrame):
69
+
70
# Human-readable description of each recognised atom-table column.
# The derived integer indices (res_index, chain_index) are not described here.
_columns = dict(
    recname='Record name',
    serial='Atom serial number',
    name='Atom name',
    altLoc='Alternate location indicator',
    resName='Residue name',
    chainID='Chain identifier',
    resSeq='Residue sequence number',
    iCode='Code for insertion of residues',
    x='Orthogonal coordinates for X in Angstroms',
    y='Orthogonal coordinates for Y in Angstroms',
    z='Orthogonal coordinates for Z in Angstroms',
    occupancy='Occupancy',
    tempFactor='Temperature factor',
    element='Element symbol',
    charge='Charge on the atom',
    model='Model number',
    molecule='Molecule name',
    resname='Residue name',
)
90
+
91
+
92
+ # Initialization
93
def __init__(self, particles, altLoc='A', model=1, **kwargs):
    """Create a Scene from a table of particles.

    The Scene is a pandas DataFrame with one row per particle plus a metadata
    dictionary stored in ``self._meta``.

    Parameters
    ----------
    particles :
        Anything accepted by the ``pandas.DataFrame`` constructor. It must have
        ``x``, ``y`` and ``z`` columns, or exactly three columns (which are then
        renamed to ``x``, ``y``, ``z``).
    altLoc, model :
        Not used by the constructor: missing ``altLoc``/``model`` columns are
        filled with '' and 1 regardless. Kept for backward compatibility.
    **kwargs :
        Stored verbatim in ``self._meta``.

    Raises
    ------
    ValueError
        If only some of x/y/z are present, or if the coordinate columns
        cannot be identified.
    """
    super().__init__(particles)
    # Written through __dict__, presumably to bypass pandas' attribute
    # handling (which may otherwise treat the assignment as a column).
    self.__dict__['_meta'] = {}

    coordinate_columns = ['x', 'y', 'z']
    present = [col in self.columns for col in coordinate_columns]
    if not all(present):
        if any(present):
            raise ValueError(f"Incomplete coordinates, missing columns: {set(['x', 'y', 'z']) - set(self.columns)}")
        elif len(self.columns) == 3:
            self.columns = coordinate_columns
        else:
            raise ValueError("Incorrect particle format")

    # Fill missing descriptive columns with defaults (insertion order preserved).
    n_rows = len(self)
    fill_values = [('chainID', 'A'), ('resSeq', 1), ('iCode', ''), ('altLoc', ''),
                   ('model', 1), ('name', None), ('element', 'C'),
                   ('occupancy', 1.0), ('tempFactor', 1.0), ('resName', '')]
    for column, value in fill_values:
        if column in self.columns:
            continue
        if column == 'name':
            self[column] = [f'P{i:03}' for i in range(n_rows)]
        else:
            self[column] = [value] * n_rows

    # Integer index for the chains, numbered in order of first appearance.
    if 'chain_index' not in self.columns:
        chain_map = {chain: idx for idx, chain in enumerate(self['chainID'].unique())}
        self['chain_index'] = self['chainID'].map(chain_map).astype(int)

    # Integer index for the residues, numbered in order of first appearance.
    if 'res_index' not in self.columns:
        # The key parts are joined with a separator: plain concatenation is
        # ambiguous (chain 1 + resSeq 23 and chain 12 + resSeq 3 both gave "123").
        residue_keys = (self['chain_index'].astype(str) + '_' +
                        self['resSeq'].astype(str) + '_' +
                        self['iCode'].astype(str))
        self['res_index'] = pandas.factorize(residue_keys)[0].astype(int)

    # Integer index for the atoms.
    if 'atom_index' not in self.columns:
        self['atom_index'] = range(len(self))

    # Remaining keyword arguments become metadata.
    self._meta.update(kwargs)
158
+
159
def set_coordinate_frames(self, frames: np.ndarray):
    """Store a trajectory of coordinates and load its first frame.

    Parameters
    ----------
    frames : np.ndarray
        Array shaped (n_frames, n_atoms, 3). n_atoms must equal the number
        of rows of the Scene.

    Raises
    ------
    TypeError
        When ``frames`` is not a NumPy array.
    ValueError
        When ``frames`` is not 3-D with a last axis of length 3, or when its
        atom axis does not match the number of rows.
    """
    if not isinstance(frames, np.ndarray):
        raise TypeError("frames must be a numpy array")
    well_shaped = frames.ndim == 3 and frames.shape[2] == 3
    if not well_shaped:
        raise ValueError("frames must be a 3D numpy array with shape (n_frames, n_atoms, 3)")
    atom_count = frames.shape[1]
    if atom_count != len(self):
        raise ValueError("The number of atoms in frames must match the number of rows in the Scene")
    self._meta['coordinate_frames'] = frames
    # The current x/y/z columns become the first frame.
    first_frame = frames[0]
    self.set_coordinates(first_frame)
185
+
186
def get_coordinate_frames(self) -> np.ndarray:
    """Return all stored coordinate frames.

    Returns
    -------
    np.ndarray
        The array stored by ``set_coordinate_frames`` (shape
        (n_frames, n_atoms, 3)). Without a stored trajectory, the current
        x/y/z columns are returned as a single frame of shape (1, n_atoms, 3).
    """
    try:
        return self._meta['coordinate_frames']
    except KeyError:
        single_frame = self.get_coordinates().to_numpy()
        return single_frame.reshape(1, -1, 3)
200
+
201
@property
def n_frames(self) -> int:
    """Number of coordinate frames held by the scene.

    Returns
    -------
    int
        Length of the leading (frame) axis of ``get_coordinate_frames()``.
    """
    return len(self.get_coordinate_frames())
212
+
213
@property
def frames(self) -> "_FrameAccessor":
    """Accessor that turns individual frames into Scenes.

    Example
    -------
    >>> frame10 = scene.frames[10]
    """
    accessor = _FrameAccessor(self)
    return accessor
223
+
224
def iterframes(self):
    """Return an iterator over the frames of the scene.

    Returns
    -------
    iterator of Scene
        One Scene per frame, each with its coordinates replaced by that frame,
        as produced by ``self.frames``.
    """
    frame_accessor = self.frames
    return iter(frame_accessor)
234
+
235
def get_frame_coordinates(self, frame_index: int) -> np.ndarray:
    """Return the coordinates of a single frame.

    Parameters
    ----------
    frame_index : int
        Position of the frame (negative values count from the end, as in NumPy).

    Returns
    -------
    np.ndarray
        Array of shape (n_atoms, 3).
    """
    all_frames = self.get_coordinate_frames()
    selected = all_frames[frame_index]
    return selected
251
+
252
def set_frame_coordinates(self, frame_index: int):
    """Make one stored frame the scene's current coordinates.

    Parameters
    ----------
    frame_index : int
        Position of the frame whose coordinates are copied into x/y/z.
    """
    chosen = self.get_coordinate_frames()[frame_index]
    self.set_coordinates(chosen)
263
+
264
def select(self, **kwargs):
    """Return the rows matching every ``column=values`` filter as a new Scene.

    Parameters
    ----------
    **kwargs
        Column name -> accepted value(s). A single value or any list-like of
        values is accepted. For ``altLoc`` the blank locations '' and '.' are
        always accepted in addition to the given ones.

    Returns
    -------
    Scene
        The selected rows with a copy of this scene's metadata. If stored
        ``coordinate_frames`` match this scene's row count, they are sliced
        to the selected atoms. If the selection contains repeated
        (chain_index, res_index, name) atoms, they are printed and the new
        scene's metadata gets ``duplicated=True``.
    """
    def as_list(values):
        # Accept a single value as well as a list-like of values.
        return list(values) if pandas.api.types.is_list_like(values) else [values]

    mask = pandas.Series(True, index=self.index)
    for column, values in kwargs.items():
        accepted = as_list(values)
        if column == 'altLoc':
            # Atoms without an alternate location are always kept.
            accepted = ['', '.'] + accepted
        mask &= self[column].isin(accepted)

    picked = self[mask]

    # Report atoms that appear more than once in the selection.
    atom_keys = picked[['chain_index', 'res_index', 'name']]
    duplicated = atom_keys.duplicated()
    meta = dict(self._meta)  # do not mutate the source scene's metadata
    if duplicated.any():
        print("Duplicated atoms found")
        print(atom_keys[duplicated])
        meta['duplicated'] = True

    # Keep stored trajectories consistent with the reduced atom set.
    frames = meta.get('coordinate_frames')
    if frames is not None and np.ndim(frames) == 3 and frames.shape[1] == len(mask):
        meta['coordinate_frames'] = frames[:, mask.to_numpy()]

    return Scene(picked, **meta)
283
+
284
def split_models(self):
    """Split the scene into one scene per model (placeholder, not implemented).

    TODO: split on ``model`` and ``altLoc``. Alternate locations can occur in
    several separate regions of a structure (e.g. PDB entry 1zir).

    Returns
    -------
    None
    """
    return None
292
+
293
@classmethod
def from_pdb(cls, file, **kwargs):
    """Read the ATOM/HETATM (and MODRES) records of a PDB file into a Scene.

    Parameters
    ----------
    file : str or path-like
        Path of the PDB file.
    **kwargs
        Forwarded to the Scene constructor (stored as metadata).

    Returns
    -------
    Scene
        One row per ATOM/HETATM record. ``model`` holds the number of the most
        recent MODEL record (1 if there is none), ``molecule`` is 0, and
        ``charge`` is the formal charge parsed from the "2+"/"1-" notation
        (0 when blank). If MODRES records are present, they are stored in the
        metadata as a ``modified_residues`` DataFrame.
    """
    columns = ['recname', 'serial', 'name', 'altLoc',
               'resName', 'chainID', 'resSeq', 'iCode',
               'x', 'y', 'z', 'occupancy', 'tempFactor',
               'element', 'charge']

    def pdb_line(line):
        # Fixed-column slices of the PDB ATOM/HETATM record.
        return dict(recname=line[0:6].strip(),
                    serial=line[6:11].strip(),
                    name=line[12:16].strip(),
                    altLoc=line[16:17].strip(),
                    resName=line[17:20].strip(),
                    chainID=line[21:22].strip(),
                    resSeq=line[22:26].strip(),
                    iCode=line[26:27].strip(),
                    x=line[30:38].strip(),
                    y=line[38:46].strip(),
                    z=line[46:54].strip(),
                    occupancy=line[54:60].strip(),
                    tempFactor=line[60:66].strip(),
                    element=line[76:78].strip(),
                    charge=line[78:80].strip())

    atoms = []
    mod_lines = []
    model_numbers = []
    model_number = 1
    with open(file, 'r') as pdb:
        for line in pdb:
            if len(line) <= 6:
                continue
            record = line[:6].rstrip()
            if record in ('ATOM', 'HETATM'):
                atoms.append(pdb_line(line))
                model_numbers.append(model_number)
            elif record == 'MODRES':
                mod_lines.append(dict(recname=str(line[0:6]).strip(),
                                      idCode=str(line[7:11]).strip(),
                                      resName=str(line[12:15]).strip(),
                                      chainID=str(line[16:17]).strip(),
                                      resSeq=int(line[18:22]),
                                      iCode=str(line[22:23]).strip(),
                                      stdRes=str(line[24:27]).strip(),
                                      comment=str(line[29:70]).strip()))
            elif record == 'MODEL':
                # The serial normally sits in columns 11-14; splitting also
                # accepts loosely formatted lines such as "MODEL 2".
                fields = line[6:].split()
                model_number = int(fields[0]) if fields else model_number + 1

    # Passing columns= keeps the expected columns even for a file without atoms.
    pdb_atoms = pandas.DataFrame(atoms, columns=columns)

    # Apply type conversions and set default values
    for col in ('serial', 'resSeq'):
        pdb_atoms[col] = pandas.to_numeric(pdb_atoms[col], errors='coerce').fillna(0).astype(int)
    for col in ('x', 'y', 'z'):
        pdb_atoms[col] = pandas.to_numeric(pdb_atoms[col], errors='coerce').fillna(0.0)
    for col in ('occupancy', 'tempFactor'):
        pdb_atoms[col] = pandas.to_numeric(pdb_atoms[col], errors='coerce').fillna(1.0)

    # PDB writes formal charges as magnitude followed by sign ("2+", "1-").
    charge = pdb_atoms['charge'].astype(str).str.strip()
    charge = charge.str.replace(r'^(\d+)\+$', r'\1', regex=True)
    charge = charge.str.replace(r'^(\d+)-$', r'-\1', regex=True)
    pdb_atoms['charge'] = pandas.to_numeric(charge, errors='coerce').fillna(0.0)

    pdb_atoms['model'] = model_numbers
    pdb_atoms['molecule'] = 0

    if mod_lines:
        kwargs.update(dict(modified_residues=pandas.DataFrame(mod_lines)))

    return cls(pdb_atoms, **kwargs)
364
+
365
@classmethod
def from_cif(cls, file_path, **kwargs):
    """Read the ``_atom_site`` loop of an mmCIF file into a Scene.

    Parameters
    ----------
    file_path : str or path-like
        Path to the mmCIF file.
    **kwargs
        Forwarded to the Scene constructor (stored as metadata).

    Returns
    -------
    Scene
        One row per ``_atom_site`` record, with the columns listed in the
        rename table below mapped to PDB-style names (``serial``, ``name``,
        ``resName``, ...). ``charge`` comes from ``pdbx_formal_charge``
        (0 when absent or unknown).

    Notes
    -----
    Each atom record is expected on a single line. Multi-line values and
    semicolon text fields are not supported.
    """
    tokenizer = re.compile(r"""'[^']*' |   # single-quoted
                               "[^"]*" |   # double-quoted
                               \#[^\n]* |  # comment
                               [^\s'"#]+   # unquoted
                            """, re.VERBOSE)

    def unquote(token):
        # Remove only one pair of surrounding quotes so that primes in atom
        # names such as O5' survive (str.strip would remove them too).
        if len(token) >= 2 and token[0] == token[-1] and token[0] in "'\"":
            return token[1:-1]
        return token

    atom_data = []
    atom_header = []
    in_atom_section = False
    with open(file_path, 'r') as handle:
        for raw_line in handle:
            line = raw_line.strip()

            # Skip empty lines and comments
            if not line or line.startswith("#"):
                continue

            if line.startswith("loop_"):
                in_atom_section = False  # a new loop starts
            elif line.startswith("_atom_site."):
                atom_header.append(line.split()[0].split('.', 1)[1])
                in_atom_section = True
            elif line.startswith(('_', 'data_', 'save_', 'stop_', 'global_')):
                # Any other tag or block keyword ends the _atom_site loop.
                in_atom_section = False
            elif in_atom_section:
                atom_data.append([unquote(token)
                                  for token in tokenizer.findall(line)
                                  if not token.startswith('#')])

    cif_atoms = pandas.DataFrame(atom_data, columns=atom_header)

    # Rename columns to pdb convention
    _cif_pdb_rename = {'id': 'serial',
                       'label_atom_id': 'name',
                       'label_alt_id': 'altLoc',
                       'label_comp_id': 'resName',
                       'label_asym_id': 'chainID',
                       'label_seq_id': 'resSeq',
                       'pdbx_PDB_ins_code': 'iCode',
                       'Cartn_x': 'x',
                       'Cartn_y': 'y',
                       'Cartn_z': 'z',
                       'occupancy': 'occupancy',
                       'B_iso_or_equiv': 'tempFactor',
                       'type_symbol': 'element',
                       'pdbx_formal_charge': 'charge',
                       'pdbx_PDB_model_num': 'model'}
    cif_atoms = cif_atoms.rename(_cif_pdb_rename, axis=1)

    # Best-effort typing: numeric-looking columns become float, or int when
    # every value is integral; anything else stays as text.
    for col in cif_atoms.columns:
        try:
            cif_atoms[col] = cif_atoms[col].astype(float)
            if ((cif_atoms[col].astype(int) - cif_atoms[col]) ** 2).sum() == 0:
                cif_atoms[col] = cif_atoms[col].astype(int)
        except ValueError:
            pass

    # Enforce types of the known columns; optional ones may be absent and are
    # then left for the Scene constructor to fill.
    for col in ('serial', 'resSeq'):
        if col in cif_atoms.columns:
            cif_atoms[col] = pandas.to_numeric(cif_atoms[col], errors='coerce').fillna(0).astype(int)
    for col, default in (('x', 0.0), ('y', 0.0), ('z', 0.0), ('occupancy', 1.0), ('tempFactor', 1.0)):
        if col in cif_atoms.columns:
            cif_atoms[col] = pandas.to_numeric(cif_atoms[col], errors='coerce').fillna(default).astype(float)
    if 'charge' in cif_atoms.columns:
        cif_atoms['charge'] = pandas.to_numeric(cif_atoms['charge'], errors='coerce').fillna(0.0)
    else:
        cif_atoms['charge'] = 0.0

    return cls(cif_atoms, **kwargs)
448
+
449
@classmethod
def from_gro(cls, gro, **kwargs):
    """Read a GROMACS ``.gro`` coordinate file into a Scene.

    Parameters
    ----------
    gro : str or path-like
        Path to the GRO file.
    **kwargs
        Forwarded to the Scene constructor (stored as metadata).

    Returns
    -------
    Scene
        Columns ``serial``, ``name``, ``resName``, ``resSeq`` and ``x``/``y``/``z``
        converted from nanometers to Angstroms. Velocities, if present, are
        ignored. The first three box values (nm) are stored as ``box_size``
        metadata unless ``box_size`` is passed explicitly.

    Raises
    ------
    ValueError
        If the header is malformed or the file holds fewer atom lines than declared.
    """
    with open(gro, 'r') as handle:
        lines = handle.read().splitlines()
    if len(lines) < 2:
        raise ValueError(f"GRO file {gro} is missing its title or atom-count line")
    try:
        n_atoms = int(lines[1].split()[0])
    except (IndexError, ValueError) as err:
        raise ValueError(f"Invalid atom-count line in GRO file {gro}: {lines[1]!r}") from err
    atom_lines = lines[2:2 + n_atoms]
    if len(atom_lines) < n_atoms:
        raise ValueError(f"GRO file {gro} declares {n_atoms} atoms but contains {len(atom_lines)}")

    # Fixed columns: resSeq(5) resName(5) name(5) serial(5), then x/y/z.
    # As in GROMACS, the coordinate field width is inferred from the distance
    # between decimal points on the first atom line (8 for the usual %8.3f).
    width = 8
    if atom_lines:
        dots = [i for i, char in enumerate(atom_lines[0][20:]) if char == '.']
        if len(dots) >= 2:
            width = dots[1] - dots[0]

    records = []
    for line in atom_lines:
        # Convert nm (GRO) to Angstroms (Scene convention).
        x, y, z = (float(line[20 + k * width:20 + (k + 1) * width]) * 10.0 for k in range(3))
        records.append({'serial': int(line[15:20]),
                        'name': line[10:15].strip(),
                        'resName': line[5:10].strip(),
                        'resSeq': int(line[0:5]),
                        'x': x, 'y': y, 'z': z})

    box_fields = lines[2 + n_atoms].split() if len(lines) > 2 + n_atoms else []
    if len(box_fields) >= 3:
        kwargs.setdefault('box_size', tuple(float(v) for v in box_fields[:3]))

    gro_atoms = pandas.DataFrame(records, columns=['serial', 'name', 'resName', 'resSeq', 'x', 'y', 'z'])
    return cls(gro_atoms, **kwargs)
452
+
453
@classmethod
def from_fixPDB(cls, filename=None, pdbfile=None, pdbxfile=None, url=None, pdbid=None,
                **kwargs):
    """Load a structure through pdbfixer, repair it, and return it as a Scene.

    Exactly one source should be given, as accepted by ``pdbfixer.PDBFixer``:
    a path (``filename``), a file-like object (``pdbfile``/``pdbxfile``), a
    ``url`` or a PDB id (``pdbid``).

    The fixer:
    - rebuilds missing residues inside chains (gaps at chain termini are skipped);
    - replaces non-standard residues;
    - removes heterogens, including water;
    - adds missing atoms and hydrogens at pH 7.0.

    Returns
    -------
    Scene
        Built by ``from_fixer`` from the repaired structure; ``kwargs`` are
        forwarded to the Scene constructor.
    """
    import pdbfixer

    def as_text(value):
        return None if value is None else str(value)

    # pdbfile/pdbxfile are file-like objects for PDBFixer and must not be
    # converted to strings; paths, URLs and ids may be path-like or str.
    fixer = pdbfixer.PDBFixer(filename=as_text(filename), pdbfile=pdbfile, pdbxfile=pdbxfile,
                              url=as_text(url), pdbid=as_text(pdbid))
    fixer.findMissingResidues()

    # Only rebuild internal gaps: drop missing residues at either chain end.
    chains = list(fixer.topology.chains())
    for chain_idx, position in list(fixer.missingResidues):
        n_residues = len(list(chains[chain_idx].residues()))
        if position == 0 or position == n_residues:
            del fixer.missingResidues[(chain_idx, position)]

    fixer.findNonstandardResidues()
    fixer.replaceNonstandardResidues()
    fixer.removeHeterogens(keepWater=False)
    fixer.findMissingAtoms()
    fixer.addMissingAtoms()  # Warning: importing 'simtk.openmm' is deprecated. Import 'openmm' instead.
    fixer.addMissingHydrogens(7.0)

    return cls.from_fixer(fixer, **kwargs)
504
+
505
@classmethod
def from_fixer(cls, fixer, **kwargs):
    """Convert a pdbfixer ``PDBFixer`` into a Scene.

    Reads ``fixer.topology.atoms()`` together with ``fixer.positions``
    (presumably any object exposing an OpenMM-style topology and positions
    works as well).

    Returns
    -------
    Scene
        One ATOM row per atom, with coordinates in Angstroms, occupancy and
        tempFactor 0, and the DataFrame index set to the atom serial.
        Atoms without an element get an empty element symbol.
    """
    import pdbfixer

    angstrom = pdbfixer.pdbfixer.unit.angstrom
    cols = ['recname', 'serial', 'name', 'altLoc',
            'resName', 'chainID', 'resSeq', 'iCode',
            'x', 'y', 'z', 'occupancy', 'tempFactor',
            'element', 'charge']
    data = []
    for atom, pos in zip(fixer.topology.atoms(), fixer.positions):
        residue = atom.residue
        xyz = pos.value_in_unit(angstrom)
        # Extra particles (e.g. virtual sites) may have no element.
        element = atom.element.symbol if atom.element is not None else ''
        data.append(dict(zip(cols, ['ATOM', int(atom.id), atom.name, '',
                                    residue.name, residue.chain.id, int(residue.id), '',
                                    xyz[0], xyz[1], xyz[2], 0, 0,
                                    element, ''])))
    atom_list = pandas.DataFrame(data, columns=cols)
    atom_list.index = atom_list['serial']
    return cls(atom_list, **kwargs)
529
+
530
@classmethod
def from_file(cls, filename, **kwargs):
    """Build a Scene from a structure file, choosing the reader by extension.

    Parameters
    ----------
    filename : str or path-like
        File ending in ``.pdb``, ``.cif`` or ``.gro`` (case-insensitive).
    **kwargs
        Forwarded to the selected reader.

    Raises
    ------
    ValueError
        If the extension is not recognised.
    """
    lowered = str(filename).lower()
    if lowered.endswith('.pdb'):
        return cls.from_pdb(filename, **kwargs)
    elif lowered.endswith('.cif'):
        return cls.from_cif(filename, **kwargs)
    elif lowered.endswith('.gro'):
        return cls.from_gro(filename, **kwargs)
    else:
        raise ValueError('Unknown file format')
540
+
541
def to_file(self, filename, **kwargs):
    """Write the scene, choosing the writer by file extension.

    Parameters
    ----------
    filename : str or path-like
        File ending in ``.pdb``, ``.cif`` or ``.gro`` (case-insensitive).
    **kwargs
        Forwarded to the selected writer (e.g. ``verbose``).

    Raises
    ------
    ValueError
        If the extension is not recognised.
    """
    lowered = str(filename).lower()
    if lowered.endswith('.pdb'):
        self.write_pdb(filename, **kwargs)
    elif lowered.endswith('.cif'):
        self.write_cif(filename, **kwargs)
    elif lowered.endswith('.gro'):
        self.write_gro(filename, **kwargs)
    else:
        raise ValueError('Unknown file format')
550
+
551
@classmethod
def concatenate(cls, scene_list):
    """Stack several scenes into one Scene with fresh chain IDs.

    Every distinct chain of every input (taken in sorted order per scene) gets
    the next name from ``utils.chain_name_generator()``. An input without a
    ``chainID`` column is treated as a single chain. The derived
    ``chain_index``/``res_index``/``atom_index`` columns are rebuilt for the
    combined scene, and the row index is reset to 0..n-1. Metadata of the
    inputs is not carried over.
    """
    chain_ids = []
    name_generator = utils.chain_name_generator()
    try:
        for scene in scene_list:
            if 'chainID' not in scene:
                chain_ids += [next(name_generator)] * len(scene)
            else:
                old_chains = sorted(scene['chainID'].unique())
                renaming = {chain: next(name_generator) for chain in old_chains}
                chain_ids += list(scene['chainID'].replace(renaming))
    finally:
        name_generator.close()

    combined = pandas.concat(scene_list)
    combined['chainID'] = chain_ids
    # Indices computed per input scene collide after stacking; drop them so
    # the constructor recomputes them for the combined table.
    for derived in ('chain_index', 'res_index', 'atom_index'):
        if derived in combined.columns:
            del combined[derived]
    combined.index = range(len(combined))
    return cls(combined)
569
+
570
+ # Writing
571
def write_pdb(self, file_name=None, verbose=False):
    """Write the scene as fixed-width (80-column) PDB ATOM records.

    Missing columns are filled with defaults. The ``serial`` column (or 1..n
    when absent) provides the atom serial numbers. When a ``molecule`` column
    is present, chain IDs are replaced by one name per distinct molecule from
    ``utils.chain_name_generator(format='pdb')``.

    Parameters
    ----------
    file_name : str, optional
        Output path. If None, the PDB text is returned instead.
    verbose : bool, optional
        Print a short progress message.

    Returns
    -------
    io.StringIO or None
        The PDB text when ``file_name`` is None, otherwise None.

    Raises
    ------
    ValueError
        If a coordinate column is missing, or a value overflows its PDB field
        (the record would not be 80 characters long).
    """
    # TODO Add connectivity output
    if verbose:
        print(f"Writing pdb file ({len(self)} atoms): {file_name}")

    pdb_table = self.copy()
    for axis in ('x', 'y', 'z'):
        if axis not in pdb_table.columns:
            raise ValueError(f'Coordinate {axis} not in particle definition')

    # Fill empty columns
    defaults = {
        'serial': np.arange(1, len(pdb_table) + 1),
        'name': 'A',
        'altLoc': '',
        'resName': 'R',
        'chainID': 'C',
        'resSeq': 1,
        'iCode': '',
        'occupancy': 0,
        'tempFactor': 0,
        'element': '',
        'charge': 0,
    }
    for column, default in defaults.items():
        if column not in pdb_table.columns:
            pdb_table[column] = default

    # Override chain names if molecule is present
    if 'molecule' in pdb_table.columns:
        # NOTE(review): from_pdb sets molecule = 0 for every atom, so a
        # multi-chain PDB read with from_pdb is written back as one chain.
        # Confirm this is intended.
        chain_names = utils.chain_name_generator(format='pdb')
        molecule_to_chain = dict(zip(pdb_table['molecule'].unique(), chain_names))
        pdb_table['chainID'] = pdb_table['molecule'].replace(molecule_to_chain)

    def as_text(column):
        return pdb_table[column].fillna('').astype(str).tolist()

    def as_number(column, integer=False):
        values = pandas.to_numeric(pdb_table[column], errors='coerce').fillna(0)
        return (values.astype(int) if integer else values.astype(float)).tolist()

    fields = zip(as_number('serial', integer=True), as_text('name'), as_text('altLoc'),
                 as_text('resName'), as_text('chainID'), as_number('resSeq', integer=True),
                 as_text('iCode'), pdb_table[['x', 'y', 'z']].to_numpy(dtype=float),
                 as_number('occupancy'), as_number('tempFactor'), as_text('element'))

    lines = []
    for (serial, name, alt_loc, res_name, chain, res_seq, i_code,
         (x, y, z), occupancy, temp_factor, element) in fields:
        # Column layout of the PDB ATOM record; charge (cols 79-80) is left blank.
        line = (f"{'ATOM':<6}{serial % 100000:>5} {name:^4}{alt_loc:1}{res_name:<3} "
                f"{chain:1}{res_seq:>4}{i_code:1}{'':3}"
                f"{x:8.3f}{y:8.3f}{z:8.3f}{occupancy:6.2f}{temp_factor:6.2f}{'':10}"
                f"{element:>2}{'':2}")
        if len(line) != 80:
            raise ValueError(f'An item in the atom table is longer than expected\n{line}')
        lines.append(line)
    pdb_text = ''.join(line + '\n' for line in lines)

    if file_name is None:
        return io.StringIO(pdb_text)
    with open(file_name, 'w+') as out:
        out.write(pdb_text)
616
+
617
def write_cif(self, file_name=None, verbose=False):
    """Write the scene as a PDBx/mmCIF file holding one ``_atom_site`` loop.

    Missing columns are filled with defaults (including ``charge``, which the
    Scene constructor does not create). Empty text values are written as '.'.

    Parameters
    ----------
    file_name : str, optional
        Output path. If None, the CIF text is returned instead.
    verbose : bool, optional
        Print a short progress message.

    Returns
    -------
    io.StringIO or None
        The CIF text when ``file_name`` is None, otherwise None.

    Raises
    ------
    ValueError
        If any of the ``x``, ``y``, ``z`` columns is missing.
    """
    # TODO Add connectivity output
    if verbose:
        print(f"Writing cif file ({len(self)} atoms): {file_name}")

    pdbx_table = self.copy()
    for axis in ('x', 'y', 'z'):
        if axis not in pdbx_table.columns:
            raise ValueError(f'Coordinate {axis} not in particle definition')

    # Fill empty columns (mmCIF model numbers start at 1)
    defaults = {
        'serial': np.arange(1, len(pdbx_table) + 1),
        'name': 'A',
        'altLoc': '?',
        'resName': 'R',
        'chainID': 'C',
        'resSeq': 1,
        'iCode': '',
        'occupancy': 0,
        'tempFactor': 0,
        'element': 'C',
        'charge': 0,
        'model': 1,
    }
    for column, default in defaults.items():
        if column not in pdbx_table.columns:
            pdbx_table[column] = default

    for column in ('serial', 'resSeq', 'model', 'charge'):
        pdbx_table[column] = pandas.to_numeric(pdbx_table[column], errors='coerce').fillna(0).astype(int)
    for column in ('x', 'y', 'z', 'occupancy', 'tempFactor'):
        pdbx_table[column] = pandas.to_numeric(pdbx_table[column], errors='coerce').fillna(0.0)
    # Text columns: convert to str (values may be numeric) and write empty ones as '.'.
    for column in ('name', 'altLoc', 'resName', 'chainID', 'iCode', 'element'):
        pdbx_table[column] = pdbx_table[column].fillna('').astype(str).str.strip().replace('', '.')

    def cif_quote(val):
        """Render one value as a CIF token, quoting it when required."""
        if not isinstance(val, str):
            if pandas.isna(val):
                return '.'
            val = str(val)
        if "'" in val and '"' in val:
            # Both quote characters present: double-quote and replace embedded double quotes.
            return '"' + val.replace('"', "'") + '"'
        if "'" in val:
            return '"' + val + '"'
        if '"' in val:
            return "'" + val + "'"
        lowered = val.lower()
        if (val == '' or any(c.isspace() for c in val) or val[0] in '_#$;[]'
                or lowered.startswith(('data_', 'save_'))
                or lowered in ('loop_', 'stop_', 'global_')):
            # Empty values, whitespace, reserved leading characters and
            # reserved words cannot appear unquoted.
            return '"' + val + '"'
        return val

    # (_atom_site tag, source column); header and data are generated from one
    # list so they cannot get out of step.
    atom_site = [
        ('group_PDB', None),
        ('id', 'serial'),
        ('label_atom_id', 'name'),
        ('label_comp_id', 'resName'),
        ('label_asym_id', 'chainID'),
        ('label_seq_id', 'resSeq'),
        ('label_alt_id', 'altLoc'),
        ('auth_atom_id', 'name'),
        ('auth_comp_id', 'resName'),
        ('auth_asym_id', 'chainID'),
        ('auth_seq_id', 'resSeq'),
        ('pdbx_PDB_ins_code', 'iCode'),
        ('Cartn_x', 'x'),
        ('Cartn_y', 'y'),
        ('Cartn_z', 'z'),
        ('occupancy', 'occupancy'),
        ('B_iso_or_equiv', 'tempFactor'),
        ('type_symbol', 'element'),
        ('pdbx_formal_charge', 'charge'),
        ('pdbx_PDB_model_num', 'model'),
    ]
    header = ['data_pdbx', '#', 'loop_'] + [f'_atom_site.{tag}' for tag, _ in atom_site]
    data_columns = [column for _, column in atom_site if column is not None]
    quoted = {column: pdbx_table[column].map(cif_quote).tolist() for column in set(data_columns)}
    rows = [' '.join(('ATOM',) + values)
            for values in zip(*(quoted[column] for column in data_columns))]

    cif_text = '\n'.join(header) + '\n' + ''.join(row + '\n' for row in rows) + '#\n'

    if file_name is None:
        return io.StringIO(cif_text)
    with open(file_name, 'w+') as out:
        out.write(cif_text)
750
+
751
def write_gro(self, file_name=None, box_size=None, verbose=False):
    """
    Write the Scene to a GRO file.

    Parameters:
    -----------
    file_name : str, optional
        Name of the output GRO file. If None, the GRO text is returned as an
        ``io.StringIO`` instead (like ``write_pdb``/``write_cif``).

    box_size : float or sequence of floats, optional
        The box dimensions in nanometers (x, y, z). A single number (including
        NumPy scalars) is used for all three. If None, the box spans the
        coordinate range plus 1.0 nm in each dimension.

    verbose : bool, optional
        If True, prints additional information.

    Returns:
    --------
    io.StringIO or None
        The GRO text when ``file_name`` is None, otherwise None.

    Raises:
    -------
    ValueError
        If required columns are missing.
    """
    if verbose:
        print(f"Writing GRO file ({len(self)} atoms): {file_name}")

    gro_atoms = self.copy()

    # Ensure required columns are present
    required_columns = ['resSeq', 'resName', 'name', 'x', 'y', 'z']
    for col in required_columns:
        if col not in gro_atoms.columns:
            raise ValueError(f"Column '{col}' is required for writing GRO file.")

    if 'serial' not in gro_atoms.columns:
        gro_atoms['serial'] = np.arange(1, len(gro_atoms) + 1)

    # GRO fields are 5 characters wide: numbers wrap, names are truncated.
    res_seq = (gro_atoms['resSeq'].astype(int) % 100000).tolist()
    serial = (gro_atoms['serial'].astype(int) % 100000).tolist()
    res_name = gro_atoms['resName'].astype(str).str[:5].tolist()
    atom_name = gro_atoms['name'].astype(str).str[:5].tolist()
    # GRO coordinates are in nanometers; the Scene stores Angstroms.
    xyz_nm = gro_atoms[['x', 'y', 'z']].to_numpy(dtype=float) / 10.0

    if box_size is None:
        extent = xyz_nm.max(axis=0) - xyz_nm.min(axis=0) if len(xyz_nm) else np.zeros(3)
        # Add a buffer of 1.0 nm to each dimension
        box_size = tuple(extent + 1.0)
    elif np.ndim(box_size) == 0:
        box_size = (box_size,) * 3
    # Plain floats format reliably with the fixed-width spec below.
    box_size = tuple(float(edge) for edge in box_size)

    lines = ['Generated by Scene.write_gro', f'{len(gro_atoms):5d}']
    for rs, rn, an, sr, (x, y, z) in zip(res_seq, res_name, atom_name, serial, xyz_nm):
        lines.append(f"{rs:5d}{rn:<5}{an:>5}{sr:5d}{x:8.3f}{y:8.3f}{z:8.3f}")
    # Write box dimensions
    lines.append(f"{box_size[0]:10.5f}{box_size[1]:10.5f}{box_size[2]:10.5f}")
    gro_text = '\n'.join(lines) + '\n'

    if file_name is None:
        return io.StringIO(gro_text)
    with open(file_name, 'w') as f:
        f.write(gro_text)
821
+
822
def write_gro_per_chain(self, base_filename, box_size=None, verbose=False):
    """
    Write every chain of the Scene to its own GRO file.

    Parameters:
    -----------
    base_filename : str
        Prefix of the output files; each chain is written to
        ``f"{base_filename}_{chainID}.gro"``.

    box_size : float or tuple of floats, optional
        Passed on to ``write_gro`` (nanometers). None lets ``write_gro`` derive
        the box from each chain's coordinates.

    verbose : bool, optional
        If True, prints progress messages.

    Raises:
    -------
    ValueError
        If the 'chainID' column is missing.
    """
    if 'chainID' not in self.columns:
        raise ValueError("Column 'chainID' is required to write GRO files per chain.")

    chain_column = self['chainID']
    for chain in chain_column.unique():
        subset = Scene(self[chain_column == chain], **self._meta)
        target = f"{base_filename}_{chain}.gro"
        if verbose:
            print(f"Writing chain '{chain}' to {target}")
        subset.write_gro(target, box_size=box_size, verbose=verbose)
853
+
854
+ # get methods
855
def get_coordinates(self):
    """Return the 'x', 'y', 'z' columns as a new table (rows keep the Scene's index)."""
    coordinate_columns = ['x', 'y', 'z']
    return self[coordinate_columns]
857
+
858
def get_sequence(self):
    """
    Placeholder for sequence extraction; not implemented yet.

    Returns
    -------
    None
        Always returns None.
    """
    return None
860
+
861
def set_coordinates(self, coordinates):
    """
    Overwrite the 'x', 'y', 'z' columns in place with ``coordinates``.

    ``coordinates`` can be anything pandas accepts for a three-column
    assignment, e.g. an (n, 3) array or a list of rows. Returns None.
    """
    xyz = ['x', 'y', 'z']
    self[xyz] = coordinates
863
+
864
def copy(self, deep=True):
    """
    Return a new Scene holding a copy of the table data.

    ``deep`` is forwarded to ``DataFrame.copy``. The metadata entries are
    handed to the new Scene as keyword arguments; the metadata values are
    not copied here.
    """
    table = super().copy(deep)
    return Scene(table, **self._meta)
866
+
867
def correct_modified_aminoacids(self):
    """
    Return a copy in which modified residues carry their standard names.

    For each row of the ``modified_residues`` metadata table, every atom whose
    'resName', 'chainID' and 'resSeq' all match that row gets its 'resName'
    replaced by the row's 'stdRes'. Without a ``modified_residues`` metadata
    entry the copy is returned unchanged. All edits go to the object returned
    by ``self.copy()``.
    """
    corrected = self.copy()
    if 'modified_residues' not in self._meta:
        return corrected
    for _, mod in corrected.modified_residues.iterrows():
        hit = (
            (corrected['resName'] == mod['resName'])
            & (corrected['chainID'] == mod['chainID'])
            & (corrected['resSeq'] == mod['resSeq'])
        )
        corrected.loc[hit, 'resName'] = mod['stdRes']
    return corrected
876
+
877
def rotate(self, rotation_matrix):
    """
    Return a new Scene whose coordinates are multiplied by ``rotation_matrix``.

    Coordinates are treated as row vectors (``new = coords @ rotation_matrix``),
    so pass the transpose of a matrix written for column vectors.
    Delegates to ``self.dot``.
    """
    transformed = self.dot(rotation_matrix)
    return transformed
879
+
880
def translate(self, other):
    """
    Return a new Scene with ``other`` added to the current coordinates.

    ``other`` is added to the (n, 3) coordinate table with ordinary pandas
    arithmetic (e.g. a 3-element list, or a Series indexed by 'x', 'y', 'z').
    NOTE(review): unlike ``__add__``, stored 'coordinate_frames' metadata is
    not translated.
    """
    shifted = self.get_coordinates() + other
    result = self.copy()
    result.set_coordinates(shifted)
    return result
884
+
885
def dot(self, other):
    """
    Return a new Scene whose coordinates are ``coords @ other``.

    The (n, 3) coordinate table is right-multiplied with ``DataFrame.dot``
    and written into a copy of the Scene.
    """
    product = self.get_coordinates().dot(other)
    result = self.copy()
    result.set_coordinates(product)
    return result
889
+
890
def distance_map(self, threshold=None) -> Union[np.ndarray, tuple]:
    """
    Pairwise inter-atomic distances, dense or sparse.

    Parameters
    ----------
    threshold : float or None
        None (the default) returns the full n×n matrix from
        ``distance_map_dense``. Any other value returns
        ``distance_map_sparse(threshold)``, i.e. a ``(pairs, dists)`` tuple
        covering every pair within ``threshold``, listed in both orders.
    """
    if threshold is not None:
        return self.distance_map_sparse(threshold)
    return self.distance_map_dense()
901
+
902
def distance_map_dense(self) -> np.ndarray:
    """
    Dense n×n matrix of Euclidean distances between all atoms.

    Computed with ``scipy.spatial.distance.pdist`` + ``squareform``.

    Returns
    -------
    np.ndarray
        Symmetric float array of shape (n_atoms, n_atoms) with zeros on the
        diagonal; an empty Scene gives shape (0, 0).
    """
    coords = self.get_coordinates().to_numpy()
    n_atoms = len(coords)
    if n_atoms < 2:
        # squareform() expands an empty condensed vector to a (1, 1) matrix,
        # which is the wrong shape for an empty Scene; with fewer than two
        # atoms there are no pairs, so build the all-zero result directly.
        return np.zeros((n_atoms, n_atoms))
    return distance.squareform(distance.pdist(coords))
909
+
910
+
911
def distance_map_sparse(self, threshold: float) -> Tuple[np.ndarray, np.ndarray]:
    """
    Distances for all atom pairs within ``threshold``, found with a k-d tree.

    Returns
    -------
    pairs : np.ndarray, shape (2M, 2)
        Index pairs [i, j]; each close pair appears twice, as (i, j) and (j, i).
    dists : np.ndarray, shape (2M,)
        Euclidean distance for the matching row of ``pairs``.

    Raises
    ------
    ValueError
        If ``threshold`` is None.
    """
    if threshold is None:
        raise ValueError("Must supply a threshold for sparse distance_map")

    xyz = self.get_coordinates().to_numpy()
    half = cKDTree(xyz).query_pairs(threshold, output_type='ndarray')  # (M, 2), i < j
    half_dists = np.linalg.norm(xyz[half[:, 0]] - xyz[half[:, 1]], axis=1)

    # Mirror every pair so both (i, j) and (j, i) are reported.
    pairs = np.concatenate([half, half[:, ::-1]], axis=0)
    dists = np.concatenate([half_dists, half_dists])
    return pairs, dists
933
+
934
def get_center(self) -> pandas.Series:
    """
    Geometric center (unweighted centroid) of the atoms.

    Returns
    -------
    pandas.Series
        Mean of the 'x', 'y' and 'z' columns, indexed by those names.
    """
    xyz = self[['x', 'y', 'z']]
    return xyz.mean(axis=0)
945
+
946
def center(self) -> "Scene":
    """
    Return a copy of the Scene translated so its centroid sits at the origin.

    Only the 'x', 'y', 'z' columns of the copy are shifted.
    NOTE(review): stored 'coordinate_frames' metadata (if any) is not shifted.

    Returns
    -------
    Scene
        A new Scene object with centered coordinates.
    """
    centroid = self.get_center()
    # Deep copy, so the caller's coordinates are left intact.
    shifted = self.copy(deep=True)
    # axis=1 aligns the centroid's x/y/z labels with the column names.
    shifted[['x', 'y', 'z']] = shifted[['x', 'y', 'z']].sub(centroid, axis=1)
    return shifted
961
+
962
+
963
def __repr__(self):
    """Atom-count header line followed by the regular DataFrame table."""
    try:
        header = f'<Scene ({len(self)})>'
        body = super().__repr__()
    except Exception:
        # Any failure while rendering (presumably a not-yet-initialized
        # object) degrades to a placeholder instead of raising from repr().
        return '<Scene (Uninitialized)>'
    return f'{header}\n{body}'
968
+
969
def __add__(self, other: Union["Scene", float, Sequence, pandas.Series]) -> "Scene":
    """
    ``Scene + Scene`` concatenates atoms; ``Scene + vector`` translates.

    With a Scene operand the two tables are stacked with a fresh 0..n-1 index
    and the result carries this Scene's metadata. Any other operand is
    converted by ``_as_delta`` to an x/y/z offset and added to the
    coordinates of a deep copy. When 'coordinate_frames' metadata exists,
    every frame is offset and frame 0 becomes the active coordinates.
    """
    if isinstance(other, Scene):
        logging.debug("Scene + Scene: concatenation")
        stacked = pandas.concat([self, other], ignore_index=True)
        # NOTE(review): other's metadata is dropped, and self's per-atom
        # metadata (e.g. 'coordinate_frames') no longer matches the atom count.
        return Scene(stacked, **self._meta)

    logging.debug(f"Scene + {type(other)}: translation")
    offset = _as_delta(other).to_numpy()  # shape (3,)
    result = self.copy(deep=True)

    if 'coordinate_frames' not in self._meta:
        logging.debug("Scene + vector: single-frame translation")
        result[['x','y','z']] = result[['x','y','z']] + offset
        return result

    logging.debug("Scene + vector: multi-frame translation")
    moved = self.get_coordinate_frames() + offset[None, None, :]
    result._meta['coordinate_frames'] = moved
    result.set_coordinates(moved[0])
    return result
990
+
991
def __radd__(self, other):
    """Reflected ``other + Scene``: delegates to ``self.__add__(other)``."""
    logging.debug(f"{type(other)} + Scene: __radd__ called")
    result = self.__add__(other)
    return result
994
+
995
def __sub__(self, other: Union["Scene", float, Sequence, pandas.Series]) -> "Scene":
    """
    ``Scene - Scene`` removes atoms; ``Scene - vector`` translates backwards.

    With a Scene operand, every atom whose 'atom_index' also occurs in
    ``other`` is dropped and the survivors are re-indexed from 0; the result
    keeps this Scene's metadata. Any other operand is converted by
    ``_as_delta`` and subtracted from the coordinates of a deep copy. When
    'coordinate_frames' metadata exists, it is subtracted from every frame
    and frame 0 becomes the active coordinates.
    """
    if isinstance(other, Scene):
        logging.debug("Scene - Scene: remove atoms with matching atom_index")
        keep = ~self['atom_index'].isin(other['atom_index'])
        remaining = self.loc[keep].reset_index(drop=True)
        return Scene(remaining, **self._meta)

    logging.debug(f"Scene - {type(other)}: translation by -delta")
    offset = _as_delta(other).to_numpy()
    result = self.copy(deep=True)

    if 'coordinate_frames' not in self._meta:
        logging.debug("Scene - vector: single-frame translation")
        result[['x','y','z']] = result[['x','y','z']].to_numpy() - offset
        return result

    logging.debug("Scene - vector: multi-frame translation")
    moved = self.get_coordinate_frames() - offset[None, None, :]
    result._meta['coordinate_frames'] = moved
    result.set_coordinates(moved[0])
    return result
1017
+
1018
def __rsub__(self, other: Union[float, Sequence, pandas.Series]):
    """
    ``vector - Scene``: every coordinate becomes ``delta - coordinate``.

    ``other`` is converted by ``_as_delta``. When 'coordinate_frames'
    metadata exists, every frame is transformed and frame 0 becomes the
    active coordinates; the result is always a deep copy.
    """
    logging.debug(f"{type(other)} - Scene: elementwise subtraction")
    offset = _as_delta(other).to_numpy()
    result = self.copy(deep=True)

    if 'coordinate_frames' not in self._meta:
        logging.debug("vector - Scene: single-frame")
        result[['x','y','z']] = offset - result[['x','y','z']].to_numpy()
        return result

    logging.debug("vector - Scene: multi-frame")
    mirrored = offset[None, None, :] - self.get_coordinate_frames()
    result._meta['coordinate_frames'] = mirrored
    result.set_coordinates(mirrored[0])
    return result
1034
+
1035
def __mul__(self, other: Union[float, Sequence, pandas.Series]) -> "Scene":
    """
    Scale coordinates by ``other``: a scalar (uniform) or an x/y/z 3-vector.

    ``other`` is converted by ``_as_delta``. When 'coordinate_frames'
    metadata exists, every frame is scaled and frame 0 becomes the active
    coordinates; the result is always a deep copy.
    """
    logging.debug(f"Scene * {type(other)}: scaling")
    scale = _as_delta(other).to_numpy()
    result = self.copy(deep=True)

    if 'coordinate_frames' not in self._meta:
        logging.debug("Scene * vector: single-frame scaling")
        result[['x','y','z']] = result[['x','y','z']].to_numpy() * scale
        return result

    logging.debug("Scene * vector: multi-frame scaling")
    scaled = self.get_coordinate_frames() * scale[None, None, :]
    result._meta['coordinate_frames'] = scaled
    result.set_coordinates(scaled[0])
    return result
1051
+
1052
def __rmul__(self, other):
    """Reflected ``other * Scene``: scaling commutes, so defer to ``__mul__``."""
    result = self.__mul__(other)
    return result
1054
+
1055
def __truediv__(self, other: Union[float, Sequence, pandas.Series]) -> "Scene":
    """
    Divide coordinates by ``other``: a scalar or an x/y/z 3-vector.

    ``other`` is converted by ``_as_delta``. When 'coordinate_frames'
    metadata exists, every frame is divided and frame 0 becomes the active
    coordinates; the result is always a deep copy.
    """
    logging.debug(f"Scene / {type(other)}: division")
    denominator = _as_delta(other).to_numpy()
    result = self.copy(deep=True)

    if 'coordinate_frames' not in self._meta:
        logging.debug("Scene / vector: single-frame division")
        result[['x','y','z']] = result[['x','y','z']].to_numpy() / denominator
        return result

    logging.debug("Scene / vector: multi-frame division")
    divided = self.get_coordinate_frames() / denominator[None, None, :]
    result._meta['coordinate_frames'] = divided
    result.set_coordinates(divided[0])
    return result
1072
+
1073
def __neg__(self) -> "Scene":
    """
    Point reflection through the origin: every coordinate changes sign.

    When 'coordinate_frames' metadata exists, every frame is negated and
    frame 0 becomes the active coordinates; the result is always a deep copy.
    """
    logging.debug("Scene: negation/reflection")
    result = self.copy(deep=True)

    if 'coordinate_frames' not in self._meta:
        logging.debug("Scene: single-frame negation")
        result[['x','y','z']] = -result[['x','y','z']].to_numpy()
        return result

    logging.debug("Scene: multi-frame negation")
    flipped = -self.get_coordinate_frames()
    result._meta['coordinate_frames'] = flipped
    result.set_coordinates(flipped[0])
    return result
1088
+
1089
@property
def _constructor(self):
    # pandas uses this to build the results of most operations. The Scene
    # type is kept only when every required column (the keys of
    # self._columns) is still present; otherwise a plain DataFrame is built.
    missing = [col for col in self._columns.keys() if col not in self.columns]
    if not missing:
        return Scene
    logging.debug("Warning: Missing required columns. Returning a standard DataFrame.")
    logging.debug(missing)
    return pandas.DataFrame
1098
+
1099
+ # def __getattr__(self, attr):
1100
+ # if '_meta' in self.__dict__ and attr in self._meta:
1101
+ # return self._meta[attr]
1102
+ # elif attr in self.columns:
1103
+ # return self[attr]
1104
+ # else:
1105
+ # raise AttributeError(f"type object {str(self.__class__)} has no attribute {str(attr)}")
1106
+
1107
+ # def __getattr__(self, attr):
1108
+ # # Safely retrieve _meta without triggering __getattr__ again.
1109
+ # meta = object.__getattribute__(self, '_meta') if '_meta' in self.__dict__ else {}
1110
+
1111
+ # if attr in meta:
1112
+ # return meta[attr]
1113
+
1114
+ # # Use object.__getattribute__ to get columns without recursion.
1115
+ # cols = object.__getattribute__(self, 'columns')
1116
+ # if attr in cols:
1117
+ # return self[attr]
1118
+
1119
+ # raise AttributeError(f"{self.__class__.__name__} has no attribute {attr}")
1120
+
1121
def __getattribute__(self, name):
    """
    Normal attribute lookup first, then a fallback to entries of ``_meta``.

    Anything the standard mechanism finds (methods, DataFrame properties such
    as 'columns') wins. If that lookup raises AttributeError and ``name`` is a
    key of the instance's ``_meta`` dict, that value is returned; otherwise
    AttributeError is raised again.
    NOTE: since DataFrame defines ``__getattr__``, Python presumably still
    tries it after this AttributeError (column access by attribute).
    """
    try:
        return super().__getattribute__(name)
    except AttributeError:
        if '_meta' in self.__dict__:
            # object.__getattribute__ reads _meta without re-entering this hook.
            meta = object.__getattribute__(self, '_meta')
            if name in meta:
                return meta[name]
        raise AttributeError(f"{self.__class__.__name__} has no attribute {name}")
1135
+
1136
def __setattr__(self, attr, value):
    """
    Route attribute assignment to a column, a DataFrame attribute, or ``_meta``.

    Precedence:
    1. ``_meta`` itself is always set as a regular attribute.
    2. An existing column name assigns that column (``self[attr] = value``).
    3. A name that ``pandas.DataFrame`` already defines is set normally.
    4. Anything else is stored as ``self._meta[attr]``.
    """
    if attr == '_meta':
        super().__setattr__(attr, value)
        return

    try:
        existing_columns = super().__getattribute__('columns')
    except AttributeError:
        # 'columns' can be unavailable, presumably while the frame is still
        # being constructed.
        existing_columns = None

    if existing_columns is not None and attr in existing_columns:
        self[attr] = value
    elif hasattr(pandas.DataFrame, attr):
        super().__setattr__(attr, value)
    else:
        self._meta[attr] = value
1158
+
1159
__array_priority__ = 1000  # Ensure Scene takes precedence in numpy ops

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
    """
    Handle numpy ufuncs applied to a Scene.

    Plain calls (``method == "__call__"``) of np.add, np.subtract,
    np.multiply, np.true_divide (Scene on the left only) and np.negative are
    routed to the matching Scene operator (``__add__``/``__radd__``,
    ``__sub__``/``__rsub__``, ``__mul__``/``__rmul__``, ``__truediv__``,
    ``__neg__``). For example, ``np.array([1, 0, 0]) + scene`` translates
    the Scene.

    Everything else is handed to ``pandas.DataFrame.__array_ufunc__``. That
    covers other ufuncs (np.sqrt, np.abs, ...), ufunc methods such as
    ``reduce``/``accumulate``, and a Scene on the right of np.true_divide.
    Returning NotImplemented for these made numpy raise TypeError.
    """
    if method == "__call__":
        logging.debug(f"Scene.__array_ufunc__({ufunc}, {method}, {inputs})")
        if ufunc == np.add:
            a, b = inputs
            if isinstance(a, Scene):
                return a.__add__(b)
            elif isinstance(b, Scene):
                return b.__radd__(a)
        elif ufunc == np.subtract:
            a, b = inputs
            if isinstance(a, Scene):
                return a.__sub__(b)
            elif isinstance(b, Scene):
                return b.__rsub__(a)
        elif ufunc == np.multiply:
            a, b = inputs
            if isinstance(a, Scene):
                return a.__mul__(b)
            elif isinstance(b, Scene):
                return b.__rmul__(a)
        elif ufunc == np.true_divide:
            a, b = inputs
            if isinstance(a, Scene):
                return a.__truediv__(b)
        elif ufunc == np.negative:
            (a,) = inputs
            if isinstance(a, Scene):
                return a.__neg__()

    # Not a Scene-specific operation: defer to pandas' generic ufunc handling
    # instead of returning NotImplemented (which makes numpy raise TypeError).
    return super().__array_ufunc__(ufunc, method, *inputs, **kwargs)
1200
+
1201
+ # helpers outside the class
1202
+
1203
def _as_delta(other) -> pandas.Series:
    """
    Normalize a scalar, a length-3 sequence, or a Series into an x/y/z Series.

    Parameters
    ----------
    other : scalar, sequence of 3 numbers, or pandas.Series
        A scalar (Python int/float or NumPy integer/floating scalar) is
        broadcast to all three axes. A Series must be indexed by exactly
        'x', 'y', 'z' (any order). Anything else must convert to a float
        array of shape (3,).

    Returns
    -------
    pandas.Series
        Float values indexed ['x', 'y', 'z'].

    Raises
    ------
    ValueError
        If a Series has a different index, or ``other`` cannot be read as a
        3-vector.
    """
    axes = ['x', 'y', 'z']
    if isinstance(other, pandas.Series):
        if set(other.index) != set(axes):
            raise ValueError(f"Series index must be ['x','y','z'], not {other.index}")
        # Reorder to x, y, z regardless of the order the caller used.
        return other.reindex(axes).astype(float)
    # np.integer / np.floating cover NumPy scalars such as np.int64 and
    # np.float32, which are not subclasses of Python int/float and used to be
    # rejected as "not a 3-vector".
    if isinstance(other, (int, float, np.integer, np.floating)):
        return pandas.Series([other] * 3, index=axes, dtype=float)
    arr = np.asarray(other, float)
    if arr.shape != (3,):
        raise ValueError(f"Cannot interpret {other!r} as a 3-vector")
    return pandas.Series(arr, index=axes)
1223
+
1224
def _negate(other):
    """Return the x/y/z offset of ``other`` (via ``_as_delta``) with its sign flipped."""
    return -_as_delta(other)
1228
+
1229
if __name__ == '__main__':
    # Ad-hoc smoke run: build a three-particle Scene from a bare x/y/z table
    # and round-trip it through PDB and CIF files in the current working
    # directory (test.pdb, test.cif, test2.pdb, particles_*.csv are written).
    particles = pandas.DataFrame([[0, 0, 0],
                                  [0, 1, 0],
                                  [0, 0, 1]],
                                 columns=['x', 'y', 'z'])
    s = Scene(particles)
    s.write_pdb('test.pdb')

    s = Scene.from_pdb('test.pdb')

    s.write_cif('test.cif')

    s = Scene.from_cif('test.cif')

    # NOTE(review): presumably fetches/repairs PDB entry 1JGE, which likely
    # needs network access -- confirm against Scene.from_fixPDB.
    s = Scene.from_fixPDB(pdbid='1JGE')

    # Chained round trip Scene -> PDB -> CIF -> PDB; each stage is dumped to
    # CSV so the four tables can be compared by hand.
    s1 = Scene(particles)
    s1.write_pdb('test.pdb')
    s2 = Scene.from_pdb('test.pdb')
    s2.write_cif('test.cif')
    s3 = Scene.from_cif('test.cif')
    s3.write_pdb('test2.pdb')
    s4 = Scene.from_pdb('test2.pdb')

    s1.to_csv('particles_1.csv')
    s2.to_csv('particles_2.csv')
    s3.to_csv('particles_3.csv')
    s4.to_csv('particles_4.csv')

    # NOTE(review): the string below is an unused scratch note (an HDFStore
    # metadata example); it is never executed. It refers to `pandas.np`,
    # which no longer exists in current pandas releases.
    """
import numpy as np
import pandas as pd

def h5store(filename, df, **kwargs):
store = pandas.HDFStore(filename)
store.put('mydata', df)
store.get_storer('mydata').attrs.metadata = kwargs
store.close()

def h5load(store):
data = store['mydata']
metadata = store.get_storer('mydata').attrs.metadata
return data, metadata

a = pandas.DataFrame(
data=pandas.np.random.randint(0, 100, (10, 5)), columns=list('ABCED'))

filename = '/tmp/data.h5'
metadata = dict(local_tz='US/Eastern')
h5store(filename, a, **metadata)
with pandas.HDFStore(filename) as store:
data, metadata = h5load(store)

print(data)
# A B C E D
# 0 9 20 92 43 25
# 1 2 64 54 0 63
# 2 22 42 3 83 81
# 3 3 71 17 64 53
# 4 52 10 41 22 43
# 5 48 85 96 72 88
# 6 10 47 2 10 78
# 7 30 80 3 59 16
# 8 13 52 98 79 65
# 9 6 93 55 40 3

$DATE$ $TIME$
"""