ocnn 2.2.0__py3-none-any.whl → 2.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ocnn/octree/points.py CHANGED
@@ -1,298 +1,317 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import numpy as np
10
- from typing import Optional, Union, List
11
-
12
-
13
class Points:
  r''' Represents a point cloud and contains some elementary transformations.

  The transformations rebind the tensor attributes to new tensors instead of
  modifying them in place, so tensors handed to the constructor are never
  altered by them.

  Args:
    points (torch.Tensor): The coordinates of the points with a shape of
        :obj:`(N, 3)`, where :obj:`N` is the number of points.
    normals (torch.Tensor or None): The point normals with a shape of
        :obj:`(N, 3)`.
    features (torch.Tensor or None): The point features with a shape of
        :obj:`(N, C)`, where :obj:`C` is the channel of features.
    labels (torch.Tensor or None): The point labels with a shape of
        :obj:`(N, K)`, where :obj:`K` is the channel of labels.
    batch_id (torch.Tensor or None): The batch indices for each point with a
        shape of :obj:`(N, 1)`.
    batch_size (int): The batch size.
  '''

  def __init__(self, points: torch.Tensor,
               normals: Optional[torch.Tensor] = None,
               features: Optional[torch.Tensor] = None,
               labels: Optional[torch.Tensor] = None,
               batch_id: Optional[torch.Tensor] = None,
               batch_size: int = 1):
    super().__init__()
    self.points = points
    self.normals = normals
    self.features = features
    self.labels = labels
    self.batch_id = batch_id
    self.batch_size = batch_size
    self.device = points.device
    self.batch_npt = None  # valid after `merge_points`

  @property
  def npt(self):
    r''' The number of points, i.e. the size of the first dimension of
    :attr:`points`. '''

    return self.points.shape[0]

  def orient_normal(self, axis: str = 'x'):
    r''' Orients the point normals along a given axis.

    Args:
      axis (str): One of :obj:`x`, :obj:`y` or :obj:`z`, in which case every
          normal whose component along :attr:`axis` is not positive is
          negated; or :obj:`xyz`, in which case every component of the
          normals is replaced by its absolute value. (default: :obj:`x`)
    '''

    if self.normals is None:
      return

    axis_map = {'x': 0, 'y': 1, 'z': 2, 'xyz': 3}
    idx = axis_map[axis]
    if idx < 3:
      flags = self.normals[:, idx] > 0
      flags = flags.float() * 2.0 - 1.0  # [0, 1] -> [-1, 1]
      self.normals = self.normals * flags.unsqueeze(1)
    else:
      # Out of place, so a tensor shared with the caller is left untouched.
      self.normals = self.normals.abs()

  def scale(self, factor: torch.Tensor):
    r''' Rescales the point cloud.

    The normals are multiplied by the inverse factor and re-normalized when
    the scaling is non-uniform or contains a reflection (a negative factor);
    a uniform positive scaling leaves them unchanged.

    Args:
      factor (torch.Tensor): The scale factor with shape :obj:`(3,)`.
    '''

    non_zero = (factor != 0).all()
    all_ones = (factor == 1.0).all()
    non_uniform = (factor != factor[0]).any()
    assert non_zero, 'The scale factor must not contain 0.'
    if all_ones:
      return

    factor = factor.to(self.device)
    self.points = self.points * factor
    # A uniform negative factor is a point reflection: the normals must be
    # flipped as well, even though the scaling is uniform.
    reflect = (factor < 0).any()
    if self.normals is not None and (non_uniform or reflect):
      ifactor = 1.0 / factor
      self.normals = self.normals * ifactor
      norm2 = torch.sqrt(torch.sum(self.normals ** 2, dim=1, keepdim=True))
      self.normals = self.normals / torch.clamp(norm2, min=1.0e-12)

  def rotate(self, angle: torch.Tensor):
    r''' Rotates the point cloud.

    Args:
      angle (torch.Tensor): The rotation angles in radian with shape :obj:`(3,)`.
    '''

    cos, sin = angle.cos(), angle.sin()
    # rotx, roty, rotz are actually the transpose of the rotation matrices
    rotx = torch.Tensor([[1, 0, 0], [0, cos[0], sin[0]], [0, -sin[0], cos[0]]])
    roty = torch.Tensor([[cos[1], 0, -sin[1]], [0, 1, 0], [sin[1], 0, cos[1]]])
    rotz = torch.Tensor([[cos[2], sin[2], 0], [-sin[2], cos[2], 0], [0, 0, 1]])
    rot = (rotx @ roty @ rotz).to(self.device)

    def _apply(x: torch.Tensor) -> torch.Tensor:
      # torch.Tensor(...) above uses the default float dtype (usually
      # float32); cast the matrix to the dtype of `x` so that e.g. float64
      # point clouds can be rotated instead of raising a dtype mismatch.
      r = rot.to(x.dtype) if x.is_floating_point() else rot
      return x @ r

    self.points = _apply(self.points)
    if self.normals is not None:
      self.normals = _apply(self.normals)

  def translate(self, dis: torch.Tensor):
    r''' Translates the point cloud.

    Args:
      dis (torch.Tensor): The displacement with shape :obj:`(3,)`.
    '''

    dis = dis.to(self.device)
    self.points = self.points + dis

  def flip(self, axis: str):
    r''' Flips the point cloud along the given :attr:`axis`.

    Args:
      axis (str): The flipping axes, a string made of the characters
          :obj:`x`, :obj:`y`, and :obj:`z` (e.g. :obj:`xz`). An axis named
          twice is flipped twice, i.e. left unchanged.
    '''

    axis_map = {'x': 0, 'y': 1, 'z': 2}
    # Accumulate one sign per coordinate first: an invalid axis raises
    # KeyError before anything is modified, and the multiplication below
    # creates new tensors instead of mutating the caller's ones in place.
    sign = torch.ones(3, dtype=self.points.dtype, device=self.device)
    for x in axis:
      idx = axis_map[x]
      sign[idx] = -sign[idx]
    self.points = self.points * sign
    if self.normals is not None:
      self.normals = self.normals * sign.to(self.normals.dtype)

  def clip(self, min: float = -1.0, max: float = 1.0, esp: float = 0.01):
    r''' Keeps only the points strictly inside :obj:`(min+esp, max-esp)` on
    every axis and returns the boolean mask of the kept points.

    Args:
      min (float): The minimum value to clip.
      max (float): The maximum value to clip.
      esp (float): The margin.
    '''

    mask = self.inbox_mask(min + esp, max - esp)
    tmp = self.__getitem__(mask)
    # Adopt all attributes of the sliced copy; note this resets `batch_npt`
    # to None, since per-batch counts are no longer valid after masking.
    self.__dict__.update(tmp.__dict__)
    return mask

  def __getitem__(self, mask: torch.Tensor):
    r''' Slices the point cloud according to a given :attr:`mask` and returns
    a new :class:`Points`.
    '''

    dummy_pts = torch.zeros(1, 3, device=self.device)
    out = Points(dummy_pts, batch_size=self.batch_size)

    out.points = self.points[mask]
    if self.normals is not None:
      out.normals = self.normals[mask]
    if self.features is not None:
      out.features = self.features[mask]
    if self.labels is not None:
      out.labels = self.labels[mask]
    if self.batch_id is not None:
      out.batch_id = self.batch_id[mask]
    return out

  def inbox_mask(self, bbmin: Union[float, torch.Tensor] = -1.0,
                 bbmax: Union[float, torch.Tensor] = 1.0):
    r''' Returns a mask indicating whether the points are strictly within the
    specified bounding box or not.
    '''

    mask_min = torch.all(self.points > bbmin, dim=1)
    mask_max = torch.all(self.points < bbmax, dim=1)
    mask = torch.logical_and(mask_min, mask_max)
    return mask

  def bbox(self):
    r''' Returns the bounding box as a tuple :obj:`(bbmin, bbmax)`.
    '''

    # torch.min and torch.max with `dim` return (values, indices)
    bbmin = self.points.min(dim=0).values
    bbmax = self.points.max(dim=0).values
    return bbmin, bbmax

  def normalize(self, bbmin: torch.Tensor, bbmax: torch.Tensor,
                scale: float = 1.0):
    r''' Normalizes the point cloud to :obj:`[-scale, scale]`.

    Args:
      bbmin (torch.Tensor): The minimum coordinates of the bounding box.
      bbmax (torch.Tensor): The maximum coordinates of the bounding box.
      scale (float): The scale factor
    '''

    center = (bbmin + bbmax) * 0.5
    box_size = (bbmax - bbmin).max() + 1.0e-6
    self.points = (self.points - center) * (2.0 * scale / box_size)

  def to(self, device: Union[torch.device, str], non_blocking: bool = False):
    r''' Moves the Points to a specified device.

    Args:
      device (torch.device or str): The destination device.
      non_blocking (bool): If True and the source is in pinned memory, the copy
          will be asynchronous with respect to the host. Otherwise, the argument
          has no effect. Default: False.
    '''

    if isinstance(device, str):
      device = torch.device(device)

    # If on the same device, directly return self
    if self.device == device:
      return self

    # Construct a new Points on the specified device; carry over the batch
    # information so that a merged batch keeps its batch size.
    points = Points(torch.zeros(1, 3, device=device),
                    batch_size=self.batch_size)
    points.batch_npt = self.batch_npt
    points.points = self.points.to(device, non_blocking=non_blocking)
    if self.normals is not None:
      points.normals = self.normals.to(device, non_blocking=non_blocking)
    if self.features is not None:
      points.features = self.features.to(device, non_blocking=non_blocking)
    if self.labels is not None:
      points.labels = self.labels.to(device, non_blocking=non_blocking)
    if self.batch_id is not None:
      points.batch_id = self.batch_id.to(device, non_blocking=non_blocking)
    return points

  def cuda(self, non_blocking: bool = False):
    r''' Moves the Points to the GPU. '''

    return self.to('cuda', non_blocking)

  def cpu(self):
    r''' Moves the Points to the CPU. '''

    return self.to('cpu')

  def save(self, filename: str, info: str = 'PNFL'):
    r''' Save the Points into npz or xyz files.

    Args:
      filename (str): The output filename, ending with :obj:`npz` or
          :obj:`xyz`.
      info (str): The information for saving: 'P' -> 'points', 'N' -> 'normals',
          'F' -> 'features', 'L' -> 'labels', 'B' -> 'batch_id'. Attributes
          that are None are skipped.

    Raises:
      ValueError: If the filename has an unsupported extension.
    '''

    mapping = {
        'P': ('points', self.points), 'N': ('normals', self.normals),
        'F': ('features', self.features), 'L': ('labels', self.labels),
        'B': ('batch_id', self.batch_id), }

    names, outs = [], []
    for key in info.upper():
      name, out = mapping[key]
      if out is not None:
        names.append(name)
        if out.dim() == 1:
          out = out.unsqueeze(1)
        outs.append(out.cpu().numpy())

    if filename.endswith('npz'):
      out_dict = dict(zip(names, outs))
      np.savez(filename, **out_dict)
    elif filename.endswith('xyz'):
      out_array = np.concatenate(outs, axis=1)
      np.savetxt(filename, out_array, fmt='%.6f')
    else:
      raise ValueError(
          f'Unsupported file extension: {filename} (expected npz or xyz)')
272
-
273
-
274
def merge_points(points: List['Points'], update_batch_info: bool = True):
  r''' Merges a list of points into one batch.

  Args:
    points (List[Points]): A list of points to merge. The batch size of each
        points in the list is assumed to be 1, and the :obj:`batch_size`,
        :obj:`batch_id`, and :obj:`batch_npt` in the points are ignored.
        Normals, features and labels are merged only if the first element
        has them.
    update_batch_info (bool): If True, sets :obj:`batch_size`,
        :obj:`batch_npt` and :obj:`batch_id` of the merged points.
  '''

  out = Points(torch.zeros(1, 3))
  out.points = torch.cat([p.points for p in points], dim=0)
  if points[0].normals is not None:
    out.normals = torch.cat([p.normals for p in points], dim=0)
  if points[0].features is not None:
    out.features = torch.cat([p.features for p in points], dim=0)
  if points[0].labels is not None:
    out.labels = torch.cat([p.labels for p in points], dim=0)
  out.device = points[0].device

  if update_batch_info:
    out.batch_size = len(points)
    # Build the counts directly as int64: a round trip through a float32
    # tensor (torch.Tensor(...)) cannot represent every count above 2**24.
    out.batch_npt = torch.tensor([p.npt for p in points], dtype=torch.long)
    out.batch_id = torch.cat([p.points.new_full((p.npt, 1), i)
                              for i, p in enumerate(points)], dim=0)
  return out
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import numpy as np
10
+ from typing import Optional, Union, List
11
+
12
+
13
class Points:
  r''' Represents a point cloud and contains some elementary transformations.

  The transformations rebind the tensor attributes to new tensors instead of
  modifying them in place, so tensors handed to the constructor are never
  altered by them.

  Args:
    points (torch.Tensor): The coordinates of the points with a shape of
        :obj:`(N, 3)`, where :obj:`N` is the number of points.
    normals (torch.Tensor or None): The point normals with a shape of
        :obj:`(N, 3)`.
    features (torch.Tensor or None): The point features with a shape of
        :obj:`(N, C)`, where :obj:`C` is the channel of features.
    labels (torch.Tensor or None): The point labels with a shape of
        :obj:`(N, K)`, where :obj:`K` is the channel of labels.
    batch_id (torch.Tensor or None): The batch indices for each point with a
        shape of :obj:`(N, 1)`.
    batch_size (int): The batch size.
  '''

  def __init__(self, points: torch.Tensor,
               normals: Optional[torch.Tensor] = None,
               features: Optional[torch.Tensor] = None,
               labels: Optional[torch.Tensor] = None,
               batch_id: Optional[torch.Tensor] = None,
               batch_size: int = 1):
    super().__init__()
    self.points = points
    self.normals = normals
    self.features = features
    self.labels = labels
    self.batch_id = batch_id
    self.batch_size = batch_size
    self.device = points.device
    self.batch_npt = None  # valid after `merge_points`
    self.check_input()

  def check_input(self):
    r''' Checks that the shapes of the input tensors match the shapes listed
    in the class docstring and agree on the number of points.
    '''

    assert self.points.dim() == 2 and self.points.size(1) == 3, \
        'points must have a shape of (N, 3)'
    npt = self.points.size(0)
    if self.normals is not None:
      assert self.normals.dim() == 2 and self.normals.size(1) == 3, \
          'normals must have a shape of (N, 3)'
      assert self.normals.size(0) == npt, 'normals must have N rows'
    if self.features is not None:
      assert self.features.dim() == 2, 'features must have a shape of (N, C)'
      assert self.features.size(0) == npt, 'features must have N rows'
    if self.labels is not None:
      assert self.labels.dim() == 2, 'labels must have a shape of (N, K)'
      assert self.labels.size(0) == npt, 'labels must have N rows'
    if self.batch_id is not None:
      assert self.batch_id.dim() == 2 and self.batch_id.size(1) == 1, \
          'batch_id must have a shape of (N, 1)'
      assert self.batch_id.size(0) == npt, 'batch_id must have N rows'

  @property
  def npt(self):
    r''' The number of points, i.e. the size of the first dimension of
    :attr:`points`. '''

    return self.points.shape[0]

  def orient_normal(self, axis: str = 'x'):
    r''' Orients the point normals along a given axis.

    Args:
      axis (str): One of :obj:`x`, :obj:`y` or :obj:`z`, in which case every
          normal whose component along :attr:`axis` is not positive is
          negated; or :obj:`xyz`, in which case every component of the
          normals is replaced by its absolute value. (default: :obj:`x`)
    '''

    if self.normals is None:
      return

    axis_map = {'x': 0, 'y': 1, 'z': 2, 'xyz': 3}
    idx = axis_map[axis]
    if idx < 3:
      flags = self.normals[:, idx] > 0
      flags = flags.float() * 2.0 - 1.0  # [0, 1] -> [-1, 1]
      self.normals = self.normals * flags.unsqueeze(1)
    else:
      # Out of place, so a tensor shared with the caller is left untouched.
      self.normals = self.normals.abs()

  def scale(self, factor: torch.Tensor):
    r''' Rescales the point cloud.

    The normals are multiplied by the inverse factor and re-normalized when
    the scaling is non-uniform or contains a reflection (a negative factor);
    a uniform positive scaling leaves them unchanged.

    Args:
      factor (torch.Tensor): The scale factor with shape :obj:`(3,)`.
    '''

    non_zero = (factor != 0).all()
    all_ones = (factor == 1.0).all()
    non_uniform = (factor != factor[0]).any()
    assert non_zero, 'The scale factor must not contain 0.'
    if all_ones:
      return

    factor = factor.to(self.device)
    self.points = self.points * factor
    # A uniform negative factor is a point reflection: the normals must be
    # flipped as well, even though the scaling is uniform.
    reflect = (factor < 0).any()
    if self.normals is not None and (non_uniform or reflect):
      ifactor = 1.0 / factor
      self.normals = self.normals * ifactor
      norm2 = torch.sqrt(torch.sum(self.normals ** 2, dim=1, keepdim=True))
      self.normals = self.normals / torch.clamp(norm2, min=1.0e-12)

  def rotate(self, angle: torch.Tensor):
    r''' Rotates the point cloud.

    Args:
      angle (torch.Tensor): The rotation angles in radian with shape :obj:`(3,)`.
    '''

    cos, sin = angle.cos(), angle.sin()
    # rotx, roty, rotz are actually the transpose of the rotation matrices
    rotx = torch.Tensor([[1, 0, 0], [0, cos[0], sin[0]], [0, -sin[0], cos[0]]])
    roty = torch.Tensor([[cos[1], 0, -sin[1]], [0, 1, 0], [sin[1], 0, cos[1]]])
    rotz = torch.Tensor([[cos[2], sin[2], 0], [-sin[2], cos[2], 0], [0, 0, 1]])
    rot = (rotx @ roty @ rotz).to(self.device)

    def _apply(x: torch.Tensor) -> torch.Tensor:
      # torch.Tensor(...) above uses the default float dtype (usually
      # float32); cast the matrix to the dtype of `x` so that e.g. float64
      # point clouds can be rotated instead of raising a dtype mismatch.
      r = rot.to(x.dtype) if x.is_floating_point() else rot
      return x @ r

    self.points = _apply(self.points)
    if self.normals is not None:
      self.normals = _apply(self.normals)

  def translate(self, dis: torch.Tensor):
    r''' Translates the point cloud.

    Args:
      dis (torch.Tensor): The displacement with shape :obj:`(3,)`.
    '''

    dis = dis.to(self.device)
    self.points = self.points + dis

  def flip(self, axis: str):
    r''' Flips the point cloud along the given :attr:`axis`.

    Args:
      axis (str): The flipping axes, a string made of the characters
          :obj:`x`, :obj:`y`, and :obj:`z` (e.g. :obj:`xz`). An axis named
          twice is flipped twice, i.e. left unchanged.
    '''

    axis_map = {'x': 0, 'y': 1, 'z': 2}
    # Accumulate one sign per coordinate first: an invalid axis raises
    # KeyError before anything is modified, and the multiplication below
    # creates new tensors instead of mutating the caller's ones in place.
    sign = torch.ones(3, dtype=self.points.dtype, device=self.device)
    for x in axis:
      idx = axis_map[x]
      sign[idx] = -sign[idx]
    self.points = self.points * sign
    if self.normals is not None:
      self.normals = self.normals * sign.to(self.normals.dtype)

  def clip(self, min: float = -1.0, max: float = 1.0, esp: float = 0.01):
    r''' Keeps only the points strictly inside :obj:`(min+esp, max-esp)` on
    every axis and returns the boolean mask of the kept points.

    Args:
      min (float): The minimum value to clip.
      max (float): The maximum value to clip.
      esp (float): The margin.
    '''

    mask = self.inbox_mask(min + esp, max - esp)
    tmp = self.__getitem__(mask)
    # Adopt all attributes of the sliced copy; note this resets `batch_npt`
    # to None, since per-batch counts are no longer valid after masking.
    self.__dict__.update(tmp.__dict__)
    return mask

  def __getitem__(self, mask: torch.Tensor):
    r''' Slices the point cloud according to a given :attr:`mask` and returns
    a new :class:`Points`.
    '''

    dummy_pts = torch.zeros(1, 3, device=self.device)
    out = Points(dummy_pts, batch_size=self.batch_size)

    out.points = self.points[mask]
    if self.normals is not None:
      out.normals = self.normals[mask]
    if self.features is not None:
      out.features = self.features[mask]
    if self.labels is not None:
      out.labels = self.labels[mask]
    if self.batch_id is not None:
      out.batch_id = self.batch_id[mask]
    return out

  def inbox_mask(self, bbmin: Union[float, torch.Tensor] = -1.0,
                 bbmax: Union[float, torch.Tensor] = 1.0):
    r''' Returns a mask indicating whether the points are strictly within the
    specified bounding box or not.
    '''

    mask_min = torch.all(self.points > bbmin, dim=1)
    mask_max = torch.all(self.points < bbmax, dim=1)
    mask = torch.logical_and(mask_min, mask_max)
    return mask

  def bbox(self):
    r''' Returns the bounding box as a tuple :obj:`(bbmin, bbmax)`.
    '''

    # torch.min and torch.max with `dim` return (values, indices)
    bbmin = self.points.min(dim=0).values
    bbmax = self.points.max(dim=0).values
    return bbmin, bbmax

  def normalize(self, bbmin: torch.Tensor, bbmax: torch.Tensor,
                scale: float = 1.0):
    r''' Normalizes the point cloud to :obj:`[-scale, scale]`.

    Args:
      bbmin (torch.Tensor): The minimum coordinates of the bounding box.
      bbmax (torch.Tensor): The maximum coordinates of the bounding box.
      scale (float): The scale factor
    '''

    center = (bbmin + bbmax) * 0.5
    box_size = (bbmax - bbmin).max() + 1.0e-6
    self.points = (self.points - center) * (2.0 * scale / box_size)

  def to(self, device: Union[torch.device, str], non_blocking: bool = False):
    r''' Moves the Points to a specified device.

    Args:
      device (torch.device or str): The destination device.
      non_blocking (bool): If True and the source is in pinned memory, the copy
          will be asynchronous with respect to the host. Otherwise, the argument
          has no effect. Default: False.
    '''

    if isinstance(device, str):
      device = torch.device(device)

    # If on the same device, directly return self
    if self.device == device:
      return self

    # Construct a new Points on the specified device; carry over the batch
    # information so that a merged batch keeps its batch size.
    points = Points(torch.zeros(1, 3, device=device),
                    batch_size=self.batch_size)
    points.batch_npt = self.batch_npt
    points.points = self.points.to(device, non_blocking=non_blocking)
    if self.normals is not None:
      points.normals = self.normals.to(device, non_blocking=non_blocking)
    if self.features is not None:
      points.features = self.features.to(device, non_blocking=non_blocking)
    if self.labels is not None:
      points.labels = self.labels.to(device, non_blocking=non_blocking)
    if self.batch_id is not None:
      points.batch_id = self.batch_id.to(device, non_blocking=non_blocking)
    return points

  def cuda(self, non_blocking: bool = False):
    r''' Moves the Points to the GPU. '''

    return self.to('cuda', non_blocking)

  def cpu(self):
    r''' Moves the Points to the CPU. '''

    return self.to('cpu')

  def save(self, filename: str, info: str = 'PNFL'):
    r''' Save the Points into npz or xyz files.

    Args:
      filename (str): The output filename, ending with :obj:`npz` or
          :obj:`xyz`.
      info (str): The information for saving: 'P' -> 'points', 'N' -> 'normals',
          'F' -> 'features', 'L' -> 'labels', 'B' -> 'batch_id'. Attributes
          that are None are skipped.

    Raises:
      ValueError: If the filename has an unsupported extension.
    '''

    mapping = {
        'P': ('points', self.points), 'N': ('normals', self.normals),
        'F': ('features', self.features), 'L': ('labels', self.labels),
        'B': ('batch_id', self.batch_id), }

    names, outs = [], []
    for key in info.upper():
      name, out = mapping[key]
      if out is not None:
        names.append(name)
        if out.dim() == 1:
          out = out.unsqueeze(1)
        outs.append(out.cpu().numpy())

    if filename.endswith('npz'):
      out_dict = dict(zip(names, outs))
      np.savez(filename, **out_dict)
    elif filename.endswith('xyz'):
      out_array = np.concatenate(outs, axis=1)
      np.savetxt(filename, out_array, fmt='%.6f')
    else:
      raise ValueError(
          f'Unsupported file extension: {filename} (expected npz or xyz)')
291
+
292
+
293
def merge_points(points: List['Points'], update_batch_info: bool = True):
  r''' Merges a list of points into one batch.

  Args:
    points (List[Points]): A list of points to merge. The batch size of each
        points in the list is assumed to be 1, and the :obj:`batch_size`,
        :obj:`batch_id`, and :obj:`batch_npt` in the points are ignored.
        Normals, features and labels are merged only if the first element
        has them.
    update_batch_info (bool): If True, sets :obj:`batch_size`,
        :obj:`batch_npt` and :obj:`batch_id` of the merged points.
  '''

  out = Points(torch.zeros(1, 3))
  out.points = torch.cat([p.points for p in points], dim=0)
  if points[0].normals is not None:
    out.normals = torch.cat([p.normals for p in points], dim=0)
  if points[0].features is not None:
    out.features = torch.cat([p.features for p in points], dim=0)
  if points[0].labels is not None:
    out.labels = torch.cat([p.labels for p in points], dim=0)
  out.device = points[0].device

  if update_batch_info:
    out.batch_size = len(points)
    # Build the counts directly as int64: a round trip through a float32
    # tensor (torch.Tensor(...)) cannot represent every count above 2**24.
    out.batch_npt = torch.tensor([p.npt for p in points], dtype=torch.long)
    out.batch_id = torch.cat([p.points.new_full((p.npt, 1), i)
                              for i, p in enumerate(points)], dim=0)
  return out