ocnn 2.2.6__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ocnn/octree/points.py CHANGED
@@ -1,322 +1,323 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import numpy as np
10
- from typing import Optional, Union, List
11
-
12
-
13
class Points:
  r''' Represents a point cloud and contains some elementary transformations.

  Args:
    points (torch.Tensor): The coordinates of the points with a shape of
        :obj:`(N, 3)`, where :obj:`N` is the number of points.
    normals (torch.Tensor or None): The point normals with a shape of
        :obj:`(N, 3)`.
    features (torch.Tensor or None): The point features with a shape of
        :obj:`(N, C)`, where :obj:`C` is the channel of features.
    labels (torch.Tensor or None): The point labels with a shape of
        :obj:`(N, K)`, where :obj:`K` is the channel of labels. A 1-D tensor
        of shape :obj:`(N,)` is also accepted and reshaped to :obj:`(N, 1)`.
    batch_id (torch.Tensor or None): The batch indices for each point with a
        shape of :obj:`(N, 1)`. A 1-D tensor of shape :obj:`(N,)` is also
        accepted and reshaped to :obj:`(N, 1)`.
    batch_size (int): The batch size.
  '''

  def __init__(self, points: torch.Tensor,
               normals: Optional[torch.Tensor] = None,
               features: Optional[torch.Tensor] = None,
               labels: Optional[torch.Tensor] = None,
               batch_id: Optional[torch.Tensor] = None,
               batch_size: int = 1):
    super().__init__()
    self.points = points
    self.normals = normals
    self.features = features
    self.labels = labels
    self.batch_id = batch_id
    self.batch_size = batch_size
    self.device = points.device
    self.batch_npt = None  # valid after `merge_points`
    self.check_input()

  def check_input(self):
    r''' Checks the shapes of the input tensors via assertions, and reshapes a
    1-D :attr:`labels` / :attr:`batch_id` to :obj:`(N, 1)`.
    '''

    assert self.points.dim() == 2 and self.points.size(1) == 3
    if self.normals is not None:
      assert self.normals.dim() == 2 and self.normals.size(1) == 3
      assert self.normals.size(0) == self.points.size(0)
    if self.features is not None:
      assert self.features.dim() == 2
      assert self.features.size(0) == self.points.size(0)
    if self.labels is not None:
      assert self.labels.dim() == 2 or self.labels.dim() == 1
      assert self.labels.size(0) == self.points.size(0)
      if self.labels.dim() == 1:
        self.labels = self.labels.unsqueeze(1)
    if self.batch_id is not None:
      assert self.batch_id.dim() == 2 or self.batch_id.dim() == 1
      assert self.batch_id.size(0) == self.points.size(0)
      if self.batch_id.dim() == 1:
        self.batch_id = self.batch_id.unsqueeze(1)
      assert self.batch_id.size(1) == 1

  @property
  def npt(self):
    r''' The number of points, i.e. the first dimension of :attr:`points`. '''
    return self.points.shape[0]

  def orient_normal(self, axis: str = 'x'):
    r''' Orients the point normals along a given axis.

    Args:
      axis (str): Choose from :obj:`x`, :obj:`y`, :obj:`z` and :obj:`xyz`.
          For a single axis, each normal whose component along that axis is
          not positive is negated. For :obj:`xyz`, every component is replaced
          by its absolute value. (default: :obj:`x`)
    '''

    if self.normals is None:
      return

    axis_map = {'x': 0, 'y': 1, 'z': 2, 'xyz': 3}
    idx = axis_map[axis]
    if idx < 3:
      # Keep normals with a positive component along `idx`, negate the rest.
      # `torch.where` preserves the dtype of the normals.
      positive = self.normals[:, idx:idx + 1] > 0
      self.normals = torch.where(positive, self.normals, -self.normals)
    else:
      # Out-of-place, so a tensor shared with the caller is left untouched.
      self.normals = self.normals.abs()

  def scale(self, factor: torch.Tensor):
    r''' Rescales the point cloud.

    Points are multiplied by :attr:`factor`. Normals are multiplied by
    :obj:`1 / factor` (the inverse transpose of a diagonal scaling) and then
    re-normalized.

    Args:
      factor (torch.Tensor): The scale factor with shape :obj:`(3,)`.
    '''

    non_zero = (factor != 0).all()
    all_ones = (factor == 1.0).all()
    non_uniform = (factor != factor[0]).any()
    has_negative = (factor < 0).any()
    assert non_zero, 'The scale factor must not constain 0.'
    if all_ones: return

    factor = factor.to(self.device)
    self.points = self.points * factor
    # A uniform positive factor leaves the normalized normals unchanged, so the
    # update can be skipped. A uniform negative factor still flips them.
    if self.normals is not None and (non_uniform or has_negative):
      ifactor = 1.0 / factor
      self.normals = self.normals * ifactor
      norm2 = torch.sqrt(torch.sum(self.normals ** 2, dim=1, keepdim=True))
      self.normals = self.normals / torch.clamp(norm2, min=1.0e-12)

  def rotate(self, angle: torch.Tensor):
    r''' Rotates the point cloud, and the normals if present.

    Args:
      angle (torch.Tensor): The rotation angles in radian with shape :obj:`(3,)`.
    '''

    c = angle.cos().tolist()
    s = angle.sin().tolist()
    # rotx, roty, rotz are actually the transpose of the rotation matrices,
    # because the points are row vectors multiplied as `points @ rot`.
    f64 = torch.float64
    rotx = torch.tensor([[1, 0, 0], [0, c[0], s[0]], [0, -s[0], c[0]]], dtype=f64)
    roty = torch.tensor([[c[1], 0, -s[1]], [0, 1, 0], [s[1], 0, c[1]]], dtype=f64)
    rotz = torch.tensor([[c[2], s[2], 0], [-s[2], c[2], 0], [0, 0, 1]], dtype=f64)
    rot = (rotx @ roty @ rotz).to(self.device)

    def _apply(x: torch.Tensor) -> torch.Tensor:
      # Match the dtype of `x` so that float64/float16 inputs do not hit a
      # dtype mismatch in the matmul; non-floating inputs use float32.
      dtype = x.dtype if x.is_floating_point() else torch.float32
      return x @ rot.to(dtype)

    self.points = _apply(self.points)
    if self.normals is not None:
      self.normals = _apply(self.normals)

  def translate(self, dis: torch.Tensor):
    r''' Translates the point cloud.

    Args:
      dis (torch.Tensor): The displacement with shape :obj:`(3,)`.
    '''

    dis = dis.to(self.device)
    self.points = self.points + dis

  def flip(self, axis: str):
    r''' Flips the point cloud along the given :attr:`axis`.

    Args:
      axis (str): The flipping axes, each character chosen from :obj:`x`,
          :obj:`y`, and :obj:`z` (e.g. :obj:`xz`). Repeating a character
          flips that axis again.
    '''

    axis_map = {'x': 0, 'y': 1, 'z': 2}
    sign = torch.ones(3, dtype=self.points.dtype, device=self.points.device)
    for x in axis:
      sign[axis_map[x]] *= -1
    # Out-of-place, so tensors shared with the caller are not modified.
    self.points = self.points * sign
    if self.normals is not None:
      self.normals = self.normals * sign.to(self.normals)

  def clip(self, min: float = -1.0, max: float = 1.0, esp: float = 0.01):
    r''' Clips the point cloud to :obj:`(min+esp, max-esp)` in place and
    returns the mask of the kept points.

    Args:
      min (float): The minimum value to clip.
      max (float): The maximum value to clip.
      esp (float): The margin.
    '''

    mask = self.inbox_mask(min + esp, max - esp)
    tmp = self.__getitem__(mask)
    self.__dict__.update(tmp.__dict__)
    return mask

  def __getitem__(self, mask: torch.Tensor):
    r''' Slices the point cloud according to a given :attr:`mask` and returns
    a new :class:`Points`. The :attr:`batch_npt` of the result is None.
    '''

    dummy_pts = torch.zeros(1, 3, device=self.device)
    out = Points(dummy_pts, batch_size=self.batch_size)

    out.points = self.points[mask]
    if self.normals is not None:
      out.normals = self.normals[mask]
    if self.features is not None:
      out.features = self.features[mask]
    if self.labels is not None:
      out.labels = self.labels[mask]
    if self.batch_id is not None:
      out.batch_id = self.batch_id[mask]
    return out

  def inbox_mask(self, bbmin: Union[float, torch.Tensor] = -1.0,
                 bbmax: Union[float, torch.Tensor] = 1.0):
    r''' Returns a boolean mask indicating whether each point lies strictly
    inside the bounding box :obj:`(bbmin, bbmax)`.
    '''

    mask_min = torch.all(self.points > bbmin, dim=1)
    mask_max = torch.all(self.points < bbmax, dim=1)
    mask = torch.logical_and(mask_min, mask_max)
    return mask

  def bbox(self):
    r''' Returns the per-axis minimum and maximum coordinates of the points.
    '''

    # torch.min and torch.max return (value, indices)
    bbmin = self.points.min(dim=0)
    bbmax = self.points.max(dim=0)
    return bbmin[0], bbmax[0]

  def normalize(self, bbmin: torch.Tensor, bbmax: torch.Tensor,
                scale: float = 1.0):
    r''' Normalizes the point cloud to :obj:`[-scale, scale]`.

    Args:
      bbmin (torch.Tensor): The minimum coordinates of the bounding box.
      bbmax (torch.Tensor): The maximum coordinates of the bounding box.
      scale (float): The scale factor
    '''

    center = (bbmin + bbmax) * 0.5
    box_size = (bbmax - bbmin).max() + 1.0e-6  # avoid division by zero
    self.points = (self.points - center) * (2.0 * scale / box_size)

  def to(self, device: Union[torch.device, str], non_blocking: bool = False):
    r''' Moves the Points to a specified device.

    Args:
      device (torch.device or str): The destination device.
      non_blocking (bool): If True and the source is in pinned memory, the copy
          will be asynchronous with respect to the host. Otherwise, the argument
          has no effect. Default: False.
    '''

    if isinstance(device, str):
      device = torch.device(device)

    # If on the same device, directly return self
    if self.device == device:
      return self

    # Construct a new Points on the specified device
    points = Points(torch.zeros(1, 3, device=device))
    points.batch_size = self.batch_size
    points.batch_npt = self.batch_npt  # shared as-is, not moved
    points.points = self.points.to(device, non_blocking=non_blocking)
    if self.normals is not None:
      points.normals = self.normals.to(device, non_blocking=non_blocking)
    if self.features is not None:
      points.features = self.features.to(device, non_blocking=non_blocking)
    if self.labels is not None:
      points.labels = self.labels.to(device, non_blocking=non_blocking)
    if self.batch_id is not None:
      points.batch_id = self.batch_id.to(device, non_blocking=non_blocking)
    return points

  def cuda(self, non_blocking: bool = False):
    r''' Moves the Points to the GPU. '''

    return self.to('cuda', non_blocking)

  def cpu(self):
    r''' Moves the Points to the CPU. '''

    return self.to('cpu')

  def save(self, filename: str, info: str = 'PNFL'):
    r''' Saves the Points into npz or xyz files.

    Args:
      filename (str): The output filename, ending with :obj:`npz` or :obj:`xyz`.
      info (str): The information for saving: 'P' -> 'points', 'N' -> 'normals',
          'F' -> 'features', 'L' -> 'labels', 'B' -> 'batch_id'. Entries that
          are None are skipped.

    Raises:
      ValueError: If the filename ends with neither :obj:`npz` nor :obj:`xyz`.
    '''

    mapping = {
        'P': ('points', self.points), 'N': ('normals', self.normals),
        'F': ('features', self.features), 'L': ('labels', self.labels),
        'B': ('batch_id', self.batch_id), }

    names, outs = [], []
    for key in info.upper():
      name, out = mapping[key]
      if out is not None:
        names.append(name)
        if out.dim() == 1:
          out = out.unsqueeze(1)
        outs.append(out.cpu().numpy())

    if filename.endswith('npz'):
      out_dict = dict(zip(names, outs))
      np.savez(filename, **out_dict)
    elif filename.endswith('xyz'):
      # Columns are concatenated in the order given by `info`.
      out_array = np.concatenate(outs, axis=1)
      np.savetxt(filename, out_array, fmt='%.6f')
    else:
      raise ValueError(
          f'Unsupported file extension: {filename} (expected npz or xyz)')
296
-
297
-
298
def merge_points(points: List['Points'], update_batch_info: bool = True):
  r''' Merges a list of points into one batch.

  Args:
    points (List[Points]): A list of points to merge. The batch size of each
        points in the list is assumed to be 1, and the :obj:`batch_size`,
        :obj:`batch_id`, and :obj:`batch_npt` in the points are ignored.
        Normals, features and labels are merged only if the first element
        has them.
    update_batch_info (bool): If True, sets :obj:`batch_size`,
        :obj:`batch_npt` and :obj:`batch_id` of the merged points.

  Returns:
    Points: The merged point cloud.

  Raises:
    ValueError: If :attr:`points` is empty.
  '''

  if len(points) == 0:
    raise ValueError('`points` must contain at least one Points object.')

  out = Points(torch.zeros(1, 3))
  out.points = torch.cat([p.points for p in points], dim=0)
  if points[0].normals is not None:
    out.normals = torch.cat([p.normals for p in points], dim=0)
  if points[0].features is not None:
    out.features = torch.cat([p.features for p in points], dim=0)
  if points[0].labels is not None:
    out.labels = torch.cat([p.labels for p in points], dim=0)
  out.device = points[0].device

  if update_batch_info:
    out.batch_size = len(points)
    # Build directly as int64. Going through a float32 tensor would round
    # point counts above 2**24.
    out.batch_npt = torch.tensor([p.npt for p in points], dtype=torch.long)
    # batch_id has the same dtype and device as each element's `points`.
    out.batch_id = torch.cat([p.points.new_full((p.npt, 1), i)
                              for i, p in enumerate(points)], dim=0)
  return out
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import numpy as np
10
+ from typing import Optional, Union, List
11
+
12
+
13
class Points:
  r''' Represents a point cloud and contains some elementary transformations.

  Args:
    points (torch.Tensor): The coordinates of the points with a shape of
        :obj:`(N, 3)`, where :obj:`N` is the number of points.
    normals (torch.Tensor or None): The point normals with a shape of
        :obj:`(N, 3)`.
    features (torch.Tensor or None): The point features with a shape of
        :obj:`(N, C)`, where :obj:`C` is the channel of features.
    labels (torch.Tensor or None): The point labels with a shape of
        :obj:`(N, K)`, where :obj:`K` is the channel of labels. A 1-D tensor
        of shape :obj:`(N,)` is also accepted and reshaped to :obj:`(N, 1)`.
    batch_id (torch.Tensor or None): The batch indices for each point with a
        shape of :obj:`(N, 1)`. A 1-D tensor of shape :obj:`(N,)` is also
        accepted and reshaped to :obj:`(N, 1)`.
    batch_size (int): The batch size.
  '''

  def __init__(self, points: torch.Tensor,
               normals: Optional[torch.Tensor] = None,
               features: Optional[torch.Tensor] = None,
               labels: Optional[torch.Tensor] = None,
               batch_id: Optional[torch.Tensor] = None,
               batch_size: int = 1):
    super().__init__()
    self.points = points
    self.normals = normals
    self.features = features
    self.labels = labels
    self.batch_id = batch_id
    self.batch_size = batch_size
    self.device = points.device
    self.batch_npt = None  # valid after `merge_points`
    self.check_input()

  def check_input(self):
    r''' Checks the shapes of the input tensors via assertions, and reshapes a
    1-D :attr:`labels` / :attr:`batch_id` to :obj:`(N, 1)`.
    '''

    assert self.points.dim() == 2 and self.points.size(1) == 3
    if self.normals is not None:
      assert self.normals.dim() == 2 and self.normals.size(1) == 3
      assert self.normals.size(0) == self.points.size(0)
    if self.features is not None:
      assert self.features.dim() == 2
      assert self.features.size(0) == self.points.size(0)
    if self.labels is not None:
      assert self.labels.dim() == 2 or self.labels.dim() == 1
      assert self.labels.size(0) == self.points.size(0)
      if self.labels.dim() == 1:
        self.labels = self.labels.unsqueeze(1)
    if self.batch_id is not None:
      assert self.batch_id.dim() == 2 or self.batch_id.dim() == 1
      assert self.batch_id.size(0) == self.points.size(0)
      if self.batch_id.dim() == 1:
        self.batch_id = self.batch_id.unsqueeze(1)
      assert self.batch_id.size(1) == 1

  @property
  def npt(self):
    r''' The number of points, i.e. the first dimension of :attr:`points`. '''
    return self.points.shape[0]

  def orient_normal(self, axis: str = 'x'):
    r''' Orients the point normals along a given axis.

    Args:
      axis (str): Choose from :obj:`x`, :obj:`y`, :obj:`z` and :obj:`xyz`.
          For a single axis, each normal whose component along that axis is
          not positive is negated. For :obj:`xyz`, every component is replaced
          by its absolute value. (default: :obj:`x`)
    '''

    if self.normals is None:
      return

    axis_map = {'x': 0, 'y': 1, 'z': 2, 'xyz': 3}
    idx = axis_map[axis]
    if idx < 3:
      # Keep normals with a positive component along `idx`, negate the rest.
      # `torch.where` preserves the dtype of the normals.
      positive = self.normals[:, idx:idx + 1] > 0
      self.normals = torch.where(positive, self.normals, -self.normals)
    else:
      # Out-of-place, so a tensor shared with the caller is left untouched.
      self.normals = self.normals.abs()

  def scale(self, factor: torch.Tensor):
    r''' Rescales the point cloud.

    Points are multiplied by :attr:`factor`. Normals are multiplied by
    :obj:`1 / factor` (the inverse transpose of a diagonal scaling) and then
    re-normalized.

    Args:
      factor (torch.Tensor): The scale factor with shape :obj:`(3,)`.
    '''

    non_zero = (factor != 0).all()
    all_ones = (factor == 1.0).all()
    non_uniform = (factor != factor[0]).any()
    has_negative = (factor < 0).any()
    assert non_zero, 'The scale factor must not constain 0.'
    if all_ones: return

    factor = factor.to(self.device)
    self.points = self.points * factor
    # A uniform positive factor leaves the normalized normals unchanged, so the
    # update can be skipped. A uniform negative factor still flips them.
    if self.normals is not None and (non_uniform or has_negative):
      ifactor = 1.0 / factor
      self.normals = self.normals * ifactor
      norm2 = torch.sqrt(torch.sum(self.normals ** 2, dim=1, keepdim=True))
      self.normals = self.normals / torch.clamp(norm2, min=1.0e-12)

  def rotate(self, angle: torch.Tensor):
    r''' Rotates the point cloud, and the normals if present.

    Args:
      angle (torch.Tensor): The rotation angles in radian with shape :obj:`(3,)`.
    '''

    c = angle.cos().tolist()
    s = angle.sin().tolist()
    # rotx, roty, rotz are actually the transpose of the rotation matrices,
    # because the points are row vectors multiplied as `points @ rot`.
    f64 = torch.float64
    rotx = torch.tensor([[1, 0, 0], [0, c[0], s[0]], [0, -s[0], c[0]]], dtype=f64)
    roty = torch.tensor([[c[1], 0, -s[1]], [0, 1, 0], [s[1], 0, c[1]]], dtype=f64)
    rotz = torch.tensor([[c[2], s[2], 0], [-s[2], c[2], 0], [0, 0, 1]], dtype=f64)
    rot = (rotx @ roty @ rotz).to(self.device)

    def _apply(x: torch.Tensor) -> torch.Tensor:
      # Match the dtype of `x` so that float64/float16 inputs do not hit a
      # dtype mismatch in the matmul; non-floating inputs use float32.
      dtype = x.dtype if x.is_floating_point() else torch.float32
      return x @ rot.to(dtype)

    self.points = _apply(self.points)
    if self.normals is not None:
      self.normals = _apply(self.normals)

  def translate(self, dis: torch.Tensor):
    r''' Translates the point cloud.

    Args:
      dis (torch.Tensor): The displacement with shape :obj:`(3,)`.
    '''

    dis = dis.to(self.device)
    self.points = self.points + dis

  def flip(self, axis: str):
    r''' Flips the point cloud along the given :attr:`axis`.

    Args:
      axis (str): The flipping axes, each character chosen from :obj:`x`,
          :obj:`y`, and :obj:`z` (e.g. :obj:`xz`). Repeating a character
          flips that axis again.
    '''

    axis_map = {'x': 0, 'y': 1, 'z': 2}
    sign = torch.ones(3, dtype=self.points.dtype, device=self.points.device)
    for x in axis:
      sign[axis_map[x]] *= -1
    # Out-of-place, so tensors shared with the caller are not modified.
    self.points = self.points * sign
    if self.normals is not None:
      self.normals = self.normals * sign.to(self.normals)

  def clip(self, min: float = -1.0, max: float = 1.0, esp: float = 0.01):
    r''' Clips the point cloud to :obj:`(min+esp, max-esp)` in place and
    returns the mask of the kept points.

    Args:
      min (float): The minimum value to clip.
      max (float): The maximum value to clip.
      esp (float): The margin.
    '''

    mask = self.inbox_mask(min + esp, max - esp)
    tmp = self.__getitem__(mask)
    self.__dict__.update(tmp.__dict__)
    return mask

  def __getitem__(self, mask: torch.Tensor):
    r''' Slices the point cloud according to a given :attr:`mask` and returns
    a new :class:`Points`. The :attr:`batch_npt` of the result is None.
    '''

    dummy_pts = torch.zeros(1, 3, device=self.device)
    out = Points(dummy_pts, batch_size=self.batch_size)

    out.points = self.points[mask]
    if self.normals is not None:
      out.normals = self.normals[mask]
    if self.features is not None:
      out.features = self.features[mask]
    if self.labels is not None:
      out.labels = self.labels[mask]
    if self.batch_id is not None:
      out.batch_id = self.batch_id[mask]
    return out

  def inbox_mask(self, bbmin: Union[float, torch.Tensor] = -1.0,
                 bbmax: Union[float, torch.Tensor] = 1.0):
    r''' Returns a boolean mask indicating whether each point lies strictly
    inside the bounding box :obj:`(bbmin, bbmax)`.
    '''

    mask_min = torch.all(self.points > bbmin, dim=1)
    mask_max = torch.all(self.points < bbmax, dim=1)
    mask = torch.logical_and(mask_min, mask_max)
    return mask

  def bbox(self):
    r''' Returns the per-axis minimum and maximum coordinates of the points.
    '''

    # torch.min and torch.max return (value, indices)
    bbmin = self.points.min(dim=0)
    bbmax = self.points.max(dim=0)
    return bbmin[0], bbmax[0]

  def normalize(self, bbmin: torch.Tensor, bbmax: torch.Tensor,
                scale: float = 1.0):
    r''' Normalizes the point cloud to :obj:`[-scale, scale]`.

    Args:
      bbmin (torch.Tensor): The minimum coordinates of the bounding box.
      bbmax (torch.Tensor): The maximum coordinates of the bounding box.
      scale (float): The scale factor
    '''

    center = (bbmin + bbmax) * 0.5
    box_size = (bbmax - bbmin).max() + 1.0e-6  # avoid division by zero
    self.points = (self.points - center) * (2.0 * scale / box_size)

  def to(self, device: Union[torch.device, str], non_blocking: bool = False):
    r''' Moves the Points to a specified device.

    Args:
      device (torch.device or str): The destination device.
      non_blocking (bool): If True and the source is in pinned memory, the copy
          will be asynchronous with respect to the host. Otherwise, the argument
          has no effect. Default: False.
    '''

    if isinstance(device, str):
      device = torch.device(device)

    # If on the same device, directly return self
    if self.device == device:
      return self

    # Construct a new Points on the specified device
    points = Points(torch.zeros(1, 3, device=device))
    points.batch_size = self.batch_size
    points.batch_npt = self.batch_npt  # shared as-is, not moved
    points.points = self.points.to(device, non_blocking=non_blocking)
    if self.normals is not None:
      points.normals = self.normals.to(device, non_blocking=non_blocking)
    if self.features is not None:
      points.features = self.features.to(device, non_blocking=non_blocking)
    if self.labels is not None:
      points.labels = self.labels.to(device, non_blocking=non_blocking)
    if self.batch_id is not None:
      points.batch_id = self.batch_id.to(device, non_blocking=non_blocking)
    return points

  def cuda(self, non_blocking: bool = False):
    r''' Moves the Points to the GPU. '''

    return self.to('cuda', non_blocking)

  def cpu(self):
    r''' Moves the Points to the CPU. '''

    return self.to('cpu')

  def save(self, filename: str, info: str = 'PNFL'):
    r''' Saves the Points into npz or xyz files.

    Args:
      filename (str): The output filename, ending with :obj:`npz` or :obj:`xyz`.
      info (str): The information for saving: 'P' -> 'points', 'N' -> 'normals',
          'F' -> 'features', 'L' -> 'labels', 'B' -> 'batch_id'. Entries that
          are None are skipped.

    Raises:
      ValueError: If the filename ends with neither :obj:`npz` nor :obj:`xyz`.
    '''

    mapping = {
        'P': ('points', self.points), 'N': ('normals', self.normals),
        'F': ('features', self.features), 'L': ('labels', self.labels),
        'B': ('batch_id', self.batch_id), }

    names, outs = [], []
    for key in info.upper():
      name, out = mapping[key]
      if out is not None:
        names.append(name)
        if out.dim() == 1:
          out = out.unsqueeze(1)
        outs.append(out.cpu().numpy())

    if filename.endswith('npz'):
      out_dict = dict(zip(names, outs))
      np.savez(filename, **out_dict)
    elif filename.endswith('xyz'):
      # Columns are concatenated in the order given by `info`.
      out_array = np.concatenate(outs, axis=1)
      np.savetxt(filename, out_array, fmt='%.6f')
    else:
      raise ValueError(
          f'Unsupported file extension: {filename} (expected npz or xyz)')
297
+
298
+
299
def merge_points(points: List['Points'], update_batch_info: bool = True):
  r''' Merges a list of points into one batch.

  Args:
    points (List[Points]): A list of points to merge. The batch size of each
        points in the list is assumed to be 1, and the :obj:`batch_size`,
        :obj:`batch_id`, and :obj:`batch_npt` in the points are ignored.
        Normals, features and labels are merged only if the first element
        has them.
    update_batch_info (bool): If True, sets :obj:`batch_size`,
        :obj:`batch_npt` and :obj:`batch_id` of the merged points.

  Returns:
    Points: The merged point cloud.

  Raises:
    ValueError: If :attr:`points` is empty.
  '''

  if len(points) == 0:
    raise ValueError('`points` must contain at least one Points object.')

  out = Points(torch.zeros(1, 3))
  out.points = torch.cat([p.points for p in points], dim=0)
  if points[0].normals is not None:
    out.normals = torch.cat([p.normals for p in points], dim=0)
  if points[0].features is not None:
    out.features = torch.cat([p.features for p in points], dim=0)
  if points[0].labels is not None:
    out.labels = torch.cat([p.labels for p in points], dim=0)
  out.device = points[0].device

  if update_batch_info:
    out.batch_size = len(points)
    # Build directly as int64. Going through a float32 tensor would round
    # point counts above 2**24.
    out.batch_npt = torch.tensor([p.npt for p in points], dtype=torch.long)
    # batch_id has the same dtype and device as each element's `points`.
    out.batch_id = torch.cat([p.points.new_full((p.npt, 1), i)
                              for i, p in enumerate(points)], dim=0)
  return out