aidsorb 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aidsorb/__init__.py +34 -0
- aidsorb/_cli.py +46 -0
- aidsorb/_internal.py +73 -0
- aidsorb/data.py +455 -0
- aidsorb/datamodules.py +262 -0
- aidsorb/litmodels.py +179 -0
- aidsorb/models.py +124 -0
- aidsorb/modules.py +403 -0
- aidsorb/pkg_data/README.md +1 -0
- aidsorb/pkg_data/periodic_table.csv +119 -0
- aidsorb/transforms.py +262 -0
- aidsorb/utils.py +203 -0
- aidsorb/visualize.py +174 -0
- aidsorb-0.0.0.dist-info/LICENSE +674 -0
- aidsorb-0.0.0.dist-info/METADATA +120 -0
- aidsorb-0.0.0.dist-info/RECORD +19 -0
- aidsorb-0.0.0.dist-info/WHEEL +5 -0
- aidsorb-0.0.0.dist-info/entry_points.txt +3 -0
- aidsorb-0.0.0.dist-info/top_level.txt +1 -0
aidsorb/__init__.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# This file is part of AIdsorb.
|
|
2
|
+
# Copyright (C) 2024 Antonios P. Sarikas
|
|
3
|
+
|
|
4
|
+
# AIdsorb is free software: you can redistribute it and/or modify
|
|
5
|
+
# it under the terms of the GNU General Public License as published by
|
|
6
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
7
|
+
# (at your option) any later version.
|
|
8
|
+
|
|
9
|
+
# This program is distributed in the hope that it will be useful,
|
|
10
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
11
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
12
|
+
# GNU General Public License for more details.
|
|
13
|
+
|
|
14
|
+
# You should have received a copy of the GNU General Public License
|
|
15
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
16
|
+
|
|
17
|
+
r"""
**AIdsorb** is a :fa:`python; fa-fade` Python package for **deep learning on
molecular point clouds**.


.. admonition:: AIdsorb adopts the following conventions

    * A ``pcd`` is represented as a :class:`numpy.ndarray` of shape ``(N, 3+C)``.
    * A molecular ``pcd`` is represented as a :class:`numpy.ndarray` of shape
      ``(N, 4+C)``, where ``N`` is the number of atoms, ``pcd[:, :3]`` are the
      **atomic coordinates**, ``pcd[:, 3]`` are the **atomic numbers** and
      ``pcd[:, 4:]`` are any **additional features**. If ``C == 0``, then the
      only features are the atomic numbers.
"""

__author__ = 'Antonios P. Sarikas'
__copyright__ = 'Copyright (c) 2024 Antonios P. Sarikas'
# The previous value carried a stray leading space (' GPL-3.0-only'), which is
# not a valid SPDX license identifier.
# NOTE(review): the license headers of this package use the GPL "or (at your
# option) any later version" wording, whose SPDX identifier is
# GPL-3.0-or-later; confirm which identifier is intended.
__license__ = 'GPL-3.0-only'
|
aidsorb/_cli.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# This file is part of AIdsorb.
|
|
2
|
+
# Copyright (C) 2024 Antonios P. Sarikas
|
|
3
|
+
|
|
4
|
+
# AIdsorb is free software: you can redistribute it and/or modify
|
|
5
|
+
# it under the terms of the GNU General Public License as published by
|
|
6
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
7
|
+
# (at your option) any later version.
|
|
8
|
+
|
|
9
|
+
# This program is distributed in the hope that it will be useful,
|
|
10
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
11
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
12
|
+
# GNU General Public License for more details.
|
|
13
|
+
|
|
14
|
+
# You should have received a copy of the GNU General Public License
|
|
15
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
16
|
+
|
|
17
|
+
r"""
|
|
18
|
+
This module provides helper functions for the CLI.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def lightning_cli():
    r"""
    Command line interface for the deep learning part.

    Wraps :class:`~aidsorb.litmodels.PointLit` (the model) and
    :class:`~aidsorb.datamodules.PCDDataModule` (the data) in a
    :class:`lightning.pytorch.cli.LightningCLI`.
    """
    # Imports are kept local to the function (presumably so that the heavy
    # deep learning stack is only loaded when this CLI actually runs).
    from lightning.pytorch.cli import LightningCLI
    from .datamodules import PCDDataModule
    from .litmodels import PointLit

    LightningCLI(model_class=PointLit, datamodule_class=PCDDataModule)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def aidsorb_fire():
    r"""
    Command line interface for creating, preparing and visualizing molecular
    point clouds.

    Exposes the following subcommands through :mod:`fire`:

    * ``visualize`` → :func:`~aidsorb.visualize.draw_pcd_from_file`
    * ``create`` → :func:`~aidsorb.utils.pcd_from_dir`
    * ``prepare`` → :func:`~aidsorb.data.prepare_data`
    """
    import fire
    from .visualize import draw_pcd_from_file
    from .utils import pcd_from_dir
    from .data import prepare_data

    # Subcommand name -> callable invoked by fire.
    commands = {
        'visualize': draw_pcd_from_file,
        'create': pcd_from_dir,
        'prepare': prepare_data,
    }
    fire.Fire(commands)
|
aidsorb/_internal.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
# This file is part of AIdsorb.
|
|
2
|
+
# Copyright (C) 2024 Antonios P. Sarikas
|
|
3
|
+
|
|
4
|
+
# AIdsorb is free software: you can redistribute it and/or modify
|
|
5
|
+
# it under the terms of the GNU General Public License as published by
|
|
6
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
7
|
+
# (at your option) any later version.
|
|
8
|
+
|
|
9
|
+
# This program is distributed in the hope that it will be useful,
|
|
10
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
11
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
12
|
+
# GNU General Public License for more details.
|
|
13
|
+
|
|
14
|
+
# You should have received a copy of the GNU General Public License
|
|
15
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
16
|
+
|
|
17
|
+
r"""
|
|
18
|
+
This module provides helper functions and data for use in other modules.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from importlib.resources import files
|
|
22
|
+
import pandas as pd
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _check_shape(array):
|
|
26
|
+
r"""
|
|
27
|
+
Check if ``array`` has valid shape to be considered a point cloud.
|
|
28
|
+
|
|
29
|
+
Parameters
|
|
30
|
+
----------
|
|
31
|
+
array
|
|
32
|
+
|
|
33
|
+
Raises
|
|
34
|
+
------
|
|
35
|
+
ValueError
|
|
36
|
+
If ``array.shape != (N, 3+C)``.
|
|
37
|
+
"""
|
|
38
|
+
if not ((array.ndim == 2) and (array.shape[1] >= 3)):
|
|
39
|
+
raise ValueError(
|
|
40
|
+
'Expecting array of shape (N, 3+C) '
|
|
41
|
+
f'but got array of shape {array.shape}!'
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _check_shape_vis(array):
|
|
46
|
+
r"""
|
|
47
|
+
Check if ``array`` has valid shape to be considered a molecular point cloud.
|
|
48
|
+
|
|
49
|
+
Parameters
|
|
50
|
+
----------
|
|
51
|
+
array
|
|
52
|
+
|
|
53
|
+
Raises
|
|
54
|
+
------
|
|
55
|
+
ValueError
|
|
56
|
+
If ``array.shape != (N, 4+C)``.
|
|
57
|
+
"""
|
|
58
|
+
if not ((array.ndim == 2) and (array.shape[1] >= 4)):
|
|
59
|
+
raise ValueError(
|
|
60
|
+
'Expecting array of shape (N, 4+C) '
|
|
61
|
+
f'but got array of shape {array.shape}!'
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# Default value for controlling randomness (e.g. it is the default ``seed`` of
# ``aidsorb.data.prepare_data``).
_SEED = 1

# Opt in to pandas Copy-on-Write mode; this will be the default on Pandas 3.0.
# NOTE(review): this assignment runs at import time and changes pandas' global
# options for the whole process, including the user's own pandas code; confirm
# this side effect is intended (pd.option_context could scope it instead).
pd.options.mode.copy_on_write = True

# Load the periodic table shipped with the package (``aidsorb.pkg_data``),
# using its ``atomic_number`` column as the row index.
with files('aidsorb.pkg_data').joinpath('periodic_table.csv').open() as fhand:
    _ptable = pd.read_csv(fhand, index_col='atomic_number')
|
aidsorb/data.py
ADDED
|
@@ -0,0 +1,455 @@
|
|
|
1
|
+
# This file is part of AIdsorb.
|
|
2
|
+
# Copyright (C) 2024 Antonios P. Sarikas
|
|
3
|
+
|
|
4
|
+
# AIdsorb is free software: you can redistribute it and/or modify
|
|
5
|
+
# it under the terms of the GNU General Public License as published by
|
|
6
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
7
|
+
# (at your option) any later version.
|
|
8
|
+
|
|
9
|
+
# This program is distributed in the hope that it will be useful,
|
|
10
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
11
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
12
|
+
# GNU General Public License for more details.
|
|
13
|
+
|
|
14
|
+
# You should have received a copy of the GNU General Public License
|
|
15
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
16
|
+
|
|
17
|
+
r"""
|
|
18
|
+
This module provides helper functions and classes for creating datasets and
|
|
19
|
+
handling point clouds of variable sizes.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
import os
|
|
23
|
+
import json
|
|
24
|
+
from pathlib import Path
|
|
25
|
+
from typing import Sequence
|
|
26
|
+
import numpy as np
|
|
27
|
+
import torch
|
|
28
|
+
from torch.utils.data import random_split, Dataset
|
|
29
|
+
from torch.nn.utils.rnn import pad_sequence
|
|
30
|
+
from . _internal import _SEED, pd
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def prepare_data(source: str, split_ratio: Sequence=(0.8, 0.1, 0.1), seed: int=_SEED):
    r"""
    Split a source of point clouds into train, validation and test sets.

    Each ``.json`` file that is created stores the names of the point clouds
    that will be used for *training*, *validation* and *testing*.

    .. warning::
        * No directory is created by :func:`prepare_data`. All ``.json`` files
          are stored under the directory containing ``source``.
        * Splitting doesn't support stratification. If your dataset is small and
          you want to perform classification, consider using
          `train_test_split`_.

    .. _train_test_split: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html

    Parameters
    ----------
    source : str
        Absolute or relative path to the ``.npz`` file holding the point clouds.
    split_ratio : sequence, default=(0.8, 0.1, 0.1)
        The sizes or fractions of splits to be produced.

        * ``split_ratio[0] == train``.
        * ``split_ratio[1] == validation``.
        * ``split_ratio[2] == test``.

    seed : int, default=1
        Controls the randomness of the ``rng`` used for splitting.

    Examples
    --------
    Before the split::

        pcd_data
        └──source.npz

    >>> prepare_data('path/to/pcd_data/source.npz')  # doctest: +SKIP

    After the split::

        pcd_data
        ├──source.npz
        ├──train.json
        ├──validation.json
        └──test.json
    """
    rng = torch.Generator().manual_seed(seed)
    path = Path(source).parent

    # ``np.load`` on an ``.npz`` returns an ``NpzFile`` that keeps the file
    # open until closed. Only the member names are needed, so read them and
    # release the handle immediately.
    with np.load(source) as pcds:
        pcd_names = pcds.files

    train, val, test = random_split(pcd_names, split_ratio, generator=rng)

    for split, mode in zip((train, val, test), ('train', 'validation', 'test')):
        # Each split is a ``Subset`` of ``pcd_names``; iterating yields names.
        names = list(split)
        with open(path / f'{mode}.json', 'w') as fhand:
            json.dump(names, fhand, indent=4)

    print('\033[32mData preparation completed!\033[0m')
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def get_names(filename):
    r"""
    Read the names stored in a ``.json`` file.

    Parameters
    ----------
    filename : str
        Path to the ``.json`` file, e.g. one written by :func:`prepare_data`.

    Returns
    -------
    names : list
        The deserialized content of the file (a list of point cloud names for
        files written by :func:`prepare_data`).
    """
    with open(filename, 'r') as json_file:
        return json.load(json_file)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def upsample_pcd(pcd, size):
    r"""
    Upsample ``pcd`` to a new ``size`` by sampling with replacement from ``pcd``.

    The original points are kept (first ``N`` rows) and ``size - N`` points,
    drawn uniformly with replacement from ``pcd``, are appended. Sampling uses
    PyTorch's RNG, so it is controlled by :func:`torch.manual_seed`.

    Parameters
    ----------
    pcd : tensor of shape (N, C)
        The original point cloud of size ``N``.
    size : int
        The size of the new point cloud. Must satisfy ``size >= N``.

    Returns
    -------
    new_pcd : tensor of shape (size, C)
        Always a new tensor, even when ``size == N``.

    Raises
    ------
    ValueError
        If ``size < N``, or if ``pcd`` is empty and ``size > 0``.

    Examples
    --------
    >>> pcd = torch.tensor([[2, 4, 5, 6]])
    >>> upsample_pcd(pcd, 3)
    tensor([[2, 4, 5, 6],
            [2, 4, 5, 6],
            [2, 4, 5, 6]])

    >>> # New points must be from pcd.
    >>> pcd = torch.randn(10, 4)
    >>> new_pcd = upsample_pcd(pcd, 20)
    >>> (new_pcd[-1] == pcd).all(1).any()  # Check for last point.
    tensor(True)

    >>> # No upsampling.
    >>> pcd = torch.randn(100, 4)
    >>> new_pcd = upsample_pcd(pcd, len(pcd))
    >>> torch.equal(pcd, new_pcd)
    True
    """
    n_points = len(pcd)
    n_samples = size - n_points

    if n_samples < 0:
        raise ValueError(
            f'Cannot upsample a point cloud of size {n_points} '
            f'to a smaller size {size}!'
        )
    if n_samples == 0:
        # Nothing to add; return a copy, as concatenating an empty tail would.
        return pcd.clone()
    if n_points == 0:
        raise ValueError('Cannot upsample an empty point cloud!')

    # Draw indices on the same device as ``pcd`` with PyTorch's RNG.
    indices = torch.randint(n_points, (n_samples,), device=pcd.device)
    new_points = pcd[indices]

    return torch.cat((pcd, new_points))
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def pad_pcds(pcds, channels_first=True, mode='upsample'):
    r"""
    Pad a sequence of variable size point clouds.

    Each point cloud must have shape ``(N_i, C)``.

    Parameters
    ----------
    pcds : sequence of tensors
        The point clouds to pad. All must share the same number of channels.
    channels_first : bool, default=True
        If ``True``, the returned batch has shape ``(B, C, T)``, otherwise
        ``(B, T, C)``.
    mode : {'zeropad', 'upsample'}, default='upsample'
        ``'zeropad'`` appends rows of zeros, while ``'upsample'`` appends
        points sampled with replacement from each point cloud.

    Returns
    -------
    batch : tensor of shape (B, T, C) or (B, C, T)
        If ``channels_first=False``, then ``batch`` has shape ``(B, T, C)``,
        where ``B == len(pcds)`` is the batch size and ``T`` is the size of
        the largest point cloud in ``pcds``. Otherwise, ``(B, C, T)``.

    Raises
    ------
    ValueError
        If ``mode`` is neither ``'zeropad'`` nor ``'upsample'``.

    See Also
    --------
    :func:`upsample_pcd` : For a description of ``'upsample'`` mode.
    :func:`torch.nn.utils.rnn.pad_sequence` : For a description of ``'zeropad'`` mode.

    Examples
    --------
    >>> x1 = torch.tensor([[1, 2, 3, 4]])
    >>> x2 = torch.tensor([[2, 5, 3, 8], [0, 2, 8, 9]])

    >>> batch = pad_pcds((x1, x2), channels_first=False)
    >>> batch
    tensor([[[1, 2, 3, 4],
             [1, 2, 3, 4]],
    <BLANKLINE>
            [[2, 5, 3, 8],
             [0, 2, 8, 9]]])

    >>> batch = pad_pcds((x1, x2), channels_first=True)
    >>> batch
    tensor([[[1, 1],
             [2, 2],
             [3, 3],
             [4, 4]],
    <BLANKLINE>
            [[2, 0],
             [5, 2],
             [3, 8],
             [8, 9]]])

    >>> batch = pad_pcds((x1, x2), channels_first=False, mode='zeropad')
    >>> batch
    tensor([[[1, 2, 3, 4],
             [0, 0, 0, 0]],
    <BLANKLINE>
            [[2, 5, 3, 8],
             [0, 2, 8, 9]]])

    >>> batch = pad_pcds((x1, x2), channels_first=True, mode='zeropad')
    >>> batch
    tensor([[[1, 0],
             [2, 0],
             [3, 0],
             [4, 0]],
    <BLANKLINE>
            [[2, 0],
             [5, 2],
             [3, 8],
             [8, 9]]])
    """
    if mode == 'zeropad':
        batch = pad_sequence(pcds, batch_first=True, padding_value=0)

    elif mode == 'upsample':
        max_len = max(len(p) for p in pcds)
        new_pcds = [upsample_pcd(p, max_len) if len(p) < max_len else p for p in pcds]
        batch = torch.stack(new_pcds)

    else:
        # Previously an unknown mode left ``batch`` unbound and surfaced as a
        # confusing UnboundLocalError below.
        raise ValueError(
            f"mode must be 'zeropad' or 'upsample', but got {mode!r}!"
        )

    # Shape (B, T, C).
    if channels_first:
        batch = batch.transpose(1, 2)  # Shape (B, C, T).

    return batch
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
class Collator:
    r"""
    Collate a sequence of samples into a ``batch``.

    Point clouds are padded (see :func:`pad_pcds`) before collation, so that
    they can be stacked into a single tensor.

    .. rubric:: Shapes

    * Input: sequence of samples

        Each sample is a tuple of tensors ``(pcd, label)``, where
        ``pcd`` has shape ``(N_i, C)`` and ``label`` has shape
        ``(n_outputs,)`` or ``()``.

    * Output: tuple of length 2

        * ``batch[0] == x`` with shape ``(B, C, T)`` if ``channels_first=True``,
          otherwise ``(B, T, C)``. ``B`` is the batch size and ``T`` is the size
          of the largest point cloud in the sequence.
        * ``batch[1] == y`` with shape ``(B, n_outputs)`` or ``(B,)``.

    .. tip::
        Use an instance of this class as ``collate_fn`` with
        ``channels_first=True``, if your model is :class:`~aidsorb.models.PointNet`.

    .. todo::
        Add functionality for collating only point clouds (useful when the
        dataset is unlabeled).

    Parameters
    ----------
    channels_first : bool, default=True
    mode : {'zeropad', 'upsample'}, default='upsample'

    See Also
    --------
    :func:`pad_pcds` : For a description of the parameters.
    :func:`upsample_pcd` : For a description of the parameters.

    Examples
    --------
    >>> sample1 = (torch.tensor([[1, 4, 5, 2]]), torch.tensor([1., 2.]))
    >>> sample2 = (torch.tensor([[0, 4, 0, 2], [2, 4, 1, 8]]), torch.tensor([7., 3.]))

    >>> collate_fn = Collator()
    >>> x, y = collate_fn((sample1, sample2))
    >>> x.shape
    torch.Size([2, 4, 2])
    >>> y.shape
    torch.Size([2, 2])
    >>> x
    tensor([[[1, 1],
             [4, 4],
             [5, 5],
             [2, 2]],
    <BLANKLINE>
            [[0, 2],
             [4, 4],
             [0, 1],
             [2, 8]]])
    >>> y
    tensor([[1., 2.],
            [7., 3.]])

    >>> collate_fn = Collator(channels_first=False, mode='zeropad')
    >>> x, y = collate_fn((sample1, sample2))
    >>> x
    tensor([[[1, 4, 5, 2],
             [0, 0, 0, 0]],
    <BLANKLINE>
            [[0, 4, 0, 2],
             [2, 4, 1, 8]]])
    >>> y
    tensor([[1., 2.],
            [7., 3.]])

    >>> # Label has shape (), i.e. is scalar.
    >>> sample1 = (torch.tensor([[3, 4, 3, 2]]), torch.tensor(0))
    >>> sample2 = (torch.tensor([[2, 4, 8, 2], [9, 4, 1, 8]]), torch.tensor(1))
    >>> collate_fn = Collator(channels_first=False, mode='zeropad')
    >>> x, y = collate_fn((sample1, sample2))
    >>> x
    tensor([[[3, 4, 3, 2],
             [0, 0, 0, 0]],
    <BLANKLINE>
            [[2, 4, 8, 2],
             [9, 4, 1, 8]]])
    >>> y
    tensor([0, 1])
    """
    def __init__(self, channels_first=True, mode='upsample'):
        self.channels_first = channels_first
        self.mode = mode

    def __call__(self, samples):
        r"""
        Pad the point clouds and stack the labels of ``samples``.

        Parameters
        ----------
        samples : sequence of tuples
            Each sample is a tuple of tensors ``(pcd, label)`` where
            ``pcd.shape == (n_points, C)`` and ``label`` has shape
            ``(n_outputs,)`` or ``()``.

        Returns
        -------
        batch : tuple of length 2
            * ``batch[0] == x`` with shape ``(B, C, T)`` or ``(B, T, C)``, where
              ``T`` is the size of the largest point cloud.
            * ``batch[1] == y`` with shape ``(B, n_outputs)`` or ``(B,)``.
        """
        # Transpose the samples: all point clouds in one group, all labels in another.
        pcds, labels = zip(*samples)

        x = pad_pcds(pcds, channels_first=self.channels_first, mode=self.mode)
        return x, torch.stack(labels)
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
class PCDDataset(Dataset):
    r"""
    ``Dataset`` for point clouds.

    .. tip::
        For implementing your own transforms, have a look at the transforms
        `tutorial`_. For more flexibility, consider implementing them as
        callable instances of classes.

    .. _tutorial: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html#transforms

    Parameters
    ----------
    pcd_names : list
        List containing the names of the point clouds.
    path_to_X : str
        Absolute or relative path to the ``.npz`` file holding the point clouds.
    path_to_Y : str, optional
        Absolute or relative path to the ``.csv`` file holding the labels of the
        point clouds.

        .. warning::
            The comma ``,`` is assumed as the field separator.

    index_col : str, optional
        Column name of the ``.csv`` file to be used as row labels. The names
        (values) under this column must follow the same naming scheme as in
        ``pcd_names``.
    labels : list, optional
        List containing the names of the properties to be predicted. If
        ``None``, every column except ``index_col`` is used. No effect if
        ``path_to_Y=None``.
    transform_x : callable, optional
        Transforms applied to ``input``, i.e to each point cloud.
    transform_y : callable, optional
        Transforms applied to ``output``. No effect if ``path_to_Y=None``.

    Raises
    ------
    ValueError
        If ``labels`` is given and is not a list.

    See Also
    --------
    :mod:`aidsorb.transforms` : For available point cloud transformations.
    """
    def __init__(
            self, pcd_names, path_to_X,
            path_to_Y=None, index_col=None, labels=None,
            transform_x=None, transform_y=None,
            ):

        if (labels is not None) and not isinstance(labels, list):
            raise ValueError('labels must be a list!')

        self._pcd_names = pcd_names
        self.path_to_X = path_to_X
        self.path_to_Y = path_to_Y
        self.labels = labels
        self.index_col = index_col
        self.transform_x = transform_x
        self.transform_y = transform_y

        # Loaded lazily on first access, see __getitem__.
        self.X = None
        self.Y = None

    @property
    def pcd_names(self):
        r"""The names of the point clouds."""
        return self._pcd_names

    def __len__(self):
        return len(self.pcd_names)

    def _usecols(self):
        r"""Return the ``.csv`` columns to read, or ``None`` to read all."""
        if self.labels is None:
            return None

        columns = list(self.labels)
        # Never pass ``None`` to ``usecols`` (pandas rejects it).
        if self.index_col is not None:
            columns.append(self.index_col)

        return columns

    @staticmethod
    def _to_float_tensor(data):
        r"""Return a ``float32`` tensor copy of ``data`` (array-like or tensor)."""
        if isinstance(data, torch.Tensor):
            # Equivalent to ``torch.tensor(data)`` (a detached copy) without
            # the warning torch emits when copy-constructing from a tensor.
            return data.detach().clone().to(torch.float)

        return torch.tensor(data, dtype=torch.float)

    def __getitem__(self, idx):
        # Account for np.load and multiprocessing: files are opened on first
        # access instead of in __init__ (presumably so each DataLoader worker
        # process opens its own handles).
        if self.X is None:
            self.X = np.load(self.path_to_X)
        if self.Y is None and self.path_to_Y is not None:
            self.Y = pd.read_csv(
                self.path_to_Y,
                index_col=self.index_col,
                usecols=self._usecols(),
            )

        name = self.pcd_names[idx]
        sample_x = self.X[name]

        if self.transform_x is not None:
            sample_x = self.transform_x(sample_x)

        # Unlabeled dataset: only the point cloud is returned.
        if self.Y is None:
            return self._to_float_tensor(sample_x)

        sample_y = self.Y.loc[name].to_numpy()

        if self.transform_y is not None:
            sample_y = self.transform_y(sample_y)

        return self._to_float_tensor(sample_x), self._to_float_tensor(sample_y)
|