ocnn 2.2.7__py3-none-any.whl → 2.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,115 +1,115 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- from typing import Optional, Union
10
-
11
-
12
- class KeyLUT:
13
-
14
- def __init__(self):
15
- r256 = torch.arange(256, dtype=torch.int64)
16
- r512 = torch.arange(512, dtype=torch.int64)
17
- zero = torch.zeros(256, dtype=torch.int64)
18
- device = torch.device('cpu')
19
-
20
- self._encode = {device: (self.xyz2key(r256, zero, zero, 8),
21
- self.xyz2key(zero, r256, zero, 8),
22
- self.xyz2key(zero, zero, r256, 8))}
23
- self._decode = {device: self.key2xyz(r512, 9)}
24
-
25
- def encode_lut(self, device=torch.device('cpu')):
26
- if device not in self._encode:
27
- cpu = torch.device('cpu')
28
- self._encode[device] = tuple(e.to(device) for e in self._encode[cpu])
29
- return self._encode[device]
30
-
31
- def decode_lut(self, device=torch.device('cpu')):
32
- if device not in self._decode:
33
- cpu = torch.device('cpu')
34
- self._decode[device] = tuple(e.to(device) for e in self._decode[cpu])
35
- return self._decode[device]
36
-
37
- def xyz2key(self, x, y, z, depth):
38
- key = torch.zeros_like(x)
39
- for i in range(depth):
40
- mask = 1 << i
41
- key = (key | ((x & mask) << (2 * i + 2)) |
42
- ((y & mask) << (2 * i + 1)) |
43
- ((z & mask) << (2 * i + 0)))
44
- return key
45
-
46
- def key2xyz(self, key, depth):
47
- x = torch.zeros_like(key)
48
- y = torch.zeros_like(key)
49
- z = torch.zeros_like(key)
50
- for i in range(depth):
51
- x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2))
52
- y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1))
53
- z = z | ((key & (1 << (3 * i + 0))) >> (2 * i + 0))
54
- return x, y, z
55
-
56
-
57
- _key_lut = KeyLUT()
58
-
59
-
60
- def xyz2key(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor,
61
- b: Optional[Union[torch.Tensor, int]] = None, depth: int = 16):
62
- r'''Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys
63
- based on pre-computed look up tables. The speed of this function is much
64
- faster than the method based on for-loop.
65
-
66
- Args:
67
- x (torch.Tensor): The x coordinate.
68
- y (torch.Tensor): The y coordinate.
69
- z (torch.Tensor): The z coordinate.
70
- b (torch.Tensor or int): The batch index of the coordinates, and should be
71
- smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of
72
- :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`.
73
- depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).
74
- '''
75
-
76
- EX, EY, EZ = _key_lut.encode_lut(x.device)
77
- x, y, z = x.long(), y.long(), z.long()
78
-
79
- mask = 255 if depth > 8 else (1 << depth) - 1
80
- key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
81
- if depth > 8:
82
- mask = (1 << (depth-8)) - 1
83
- key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
84
- key = key16 << 24 | key
85
-
86
- if b is not None:
87
- b = b.long()
88
- key = b << 48 | key
89
-
90
- return key
91
-
92
-
93
- def key2xyz(key: torch.Tensor, depth: int = 16):
94
- r'''Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates
95
- and the batch index based on pre-computed look up tables.
96
-
97
- Args:
98
- key (torch.Tensor): The shuffled key.
99
- depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).
100
- '''
101
-
102
- DX, DY, DZ = _key_lut.decode_lut(key.device)
103
- x, y, z = torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key)
104
-
105
- b = key >> 48
106
- key = key & ((1 << 48) - 1)
107
-
108
- n = (depth + 2) // 3
109
- for i in range(n):
110
- k = key >> (i * 9) & 511
111
- x = x | (DX[k] << (i * 3))
112
- y = y | (DY[k] << (i * 3))
113
- z = z | (DZ[k] << (i * 3))
114
-
115
- return x, y, z, b
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ from typing import Optional, Union
10
+
11
+
12
class KeyLUT:
  r''' Builds and caches per-device look-up tables for encoding/decoding
  shuffled (Morton) keys.
  '''

  def __init__(self):
    cpu = torch.device('cpu')
    octant = torch.arange(256, dtype=torch.int64)
    chunk = torch.arange(512, dtype=torch.int64)
    none = torch.zeros(256, dtype=torch.int64)

    # Per-axis encode tables: each maps an 8-bit coordinate to its
    # interleaved bit positions inside a 24-bit shuffled key.
    self._encode = {cpu: (self.xyz2key(octant, none, none, 8),
                          self.xyz2key(none, octant, none, 8),
                          self.xyz2key(none, none, octant, 8))}
    # Decode table: maps every 9-bit key chunk back to 3-bit (x, y, z) parts.
    self._decode = {cpu: self.key2xyz(chunk, 9)}

  def encode_lut(self, device=torch.device('cpu')):
    r''' Returns the encode tables on :attr:`device`, copying them lazily
    from the CPU master copy on first use.
    '''
    if device not in self._encode:
      master = self._encode[torch.device('cpu')]
      self._encode[device] = tuple(t.to(device) for t in master)
    return self._encode[device]

  def decode_lut(self, device=torch.device('cpu')):
    r''' Returns the decode tables on :attr:`device`, copying them lazily
    from the CPU master copy on first use.
    '''
    if device not in self._decode:
      master = self._decode[torch.device('cpu')]
      self._decode[device] = tuple(t.to(device) for t in master)
    return self._decode[device]

  def xyz2key(self, x, y, z, depth):
    r''' Interleaves the low :attr:`depth` bits of :attr:`x`, :attr:`y`,
    :attr:`z` into a single shuffled key (x in the highest bit of each triple).
    '''
    key = torch.zeros_like(x)
    for i in range(depth):
      bit = 1 << i
      key = (key
             | ((x & bit) << (2 * i + 2))
             | ((y & bit) << (2 * i + 1))
             | ((z & bit) << (2 * i)))
    return key

  def key2xyz(self, key, depth):
    r''' De-interleaves a shuffled key of :attr:`depth` bit-triples back into
    :attr:`x`, :attr:`y`, :attr:`z` coordinates.
    '''
    x, y, z = (torch.zeros_like(key) for _ in range(3))
    for i in range(depth):
      base = 3 * i
      x = x | ((key & (1 << (base + 2))) >> (2 * i + 2))
      y = y | ((key & (1 << (base + 1))) >> (2 * i + 1))
      z = z | ((key & (1 << base)) >> (2 * i))
    return x, y, z
55
+
56
+
57
# Module-level singleton so the LUTs are built once and shared by the
# free functions `xyz2key` / `key2xyz` below.
_key_lut = KeyLUT()
58
+
59
+
60
def xyz2key(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor,
            b: Optional[Union[torch.Tensor, int]] = None, depth: int = 16):
  r'''Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys
  based on pre-computed look up tables. The speed of this function is much
  faster than the method based on for-loop.

  Args:
    x (torch.Tensor): The x coordinate.
    y (torch.Tensor): The y coordinate.
    z (torch.Tensor): The z coordinate.
    b (torch.Tensor or int): The batch index of the coordinates, and should be
        smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of
        :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`.
    depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).
  '''

  EX, EY, EZ = _key_lut.encode_lut(x.device)
  x, y, z = x.long(), y.long(), z.long()

  # Encode the low (at most 8) bits of each coordinate via the 256-entry LUTs.
  mask = 255 if depth > 8 else (1 << depth) - 1
  key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
  if depth > 8:
    # Encode the remaining high bits and place them above the low 24 key bits.
    mask = (1 << (depth - 8)) - 1
    key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
    key = key16 << 24 | key

  if b is not None:
    # Fix: `b` may be a plain Python int (per the docstring), which has no
    # `.long()` method; only convert when it is actually a tensor.
    if isinstance(b, torch.Tensor):
      b = b.long()
    key = b << 48 | key

  return key
91
+
92
+
93
def key2xyz(key: torch.Tensor, depth: int = 16):
  r'''Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates
  and the batch index based on pre-computed look up tables.

  Args:
    key (torch.Tensor): The shuffled key.
    depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).
  '''

  DX, DY, DZ = _key_lut.decode_lut(key.device)

  # Separate the batch index (bits 48 and above) from the coordinate bits.
  b = key >> 48
  key = key & ((1 << 48) - 1)

  x = torch.zeros_like(key)
  y = torch.zeros_like(key)
  z = torch.zeros_like(key)

  # Each iteration decodes one 9-bit key chunk (3 bits per axis) via the
  # 512-entry LUTs and merges the result into the running coordinates.
  chunks = (depth + 2) // 3
  for i in range(chunks):
    idx = (key >> (i * 9)) & 511
    shift = i * 3
    x = x | (DX[idx] << shift)
    y = y | (DY[idx] << shift)
    z = z | (DZ[idx] << shift)

  return x, y, z, b
ocnn/utils.py CHANGED
@@ -1,205 +1,205 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import math
9
- import torch
10
- from typing import Optional
11
- from packaging import version
12
-
13
- import ocnn
14
-
15
-
16
- __all__ = ['trunc_div', 'meshgrid', 'cumsum', 'scatter_add', 'xavier_uniform_',
17
- 'resize_with_last_val', 'list2str', 'build_example_octree']
18
- classes = __all__
19
-
20
-
21
- def trunc_div(input, other):
22
- r''' Wraps :func:`torch.div` for compatibility. It rounds the results of the
23
- division towards zero and is equivalent to C-style integer division.
24
- '''
25
-
26
- larger_than_171 = version.parse(torch.__version__) > version.parse('1.7.1')
27
-
28
- if larger_than_171:
29
- return torch.div(input, other, rounding_mode='trunc')
30
- else:
31
- return torch.floor_divide(input, other)
32
-
33
-
34
- def meshgrid(*tensors, indexing: Optional[str] = None):
35
- r''' Wraps :func:`torch.meshgrid` for compatibility.
36
- '''
37
-
38
- larger_than_191 = version.parse(torch.__version__) > version.parse('1.9.1')
39
-
40
- if larger_than_191:
41
- return torch.meshgrid(*tensors, indexing=indexing)
42
- else:
43
- return torch.meshgrid(*tensors)
44
-
45
-
46
- def range_grid(min: int, max: int, device: torch.device = 'cpu'):
47
- r''' Builds a 3D mesh grid in :obj:`[min, max]` (:attr:`max` included).
48
-
49
- Args:
50
- min (int): The minimum value of the grid.
51
- max (int): The maximum value of the grid.
52
- device (torch.device, optional): The device to place the grid on.
53
-
54
- Returns:
55
- torch.Tensor: A 3D mesh grid tensor of shape (N, 3), where N is the total
56
- number of grid points.
57
-
58
- Example:
59
- >>> grid = range_grid(0, 1)
60
- >>> print(grid)
61
- tensor([[0, 0, 0],
62
- [0, 0, 1],
63
- [0, 1, 0],
64
- [0, 1, 1],
65
- [1, 0, 0],
66
- [1, 0, 1],
67
- [1, 1, 0],
68
- [1, 1, 1]])
69
- '''
70
-
71
- rng = torch.arange(min, max+1, dtype=torch.long, device=device)
72
- grid = meshgrid(rng, rng, rng, indexing='ij')
73
- grid = torch.stack(grid, dim=-1).view(-1, 3)
74
- return grid
75
-
76
-
77
- def cumsum(data: torch.Tensor, dim: int, exclusive: bool = False):
78
- r''' Extends :func:`torch.cumsum` with the input argument :attr:`exclusive`.
79
-
80
- Args:
81
- data (torch.Tensor): The input data.
82
- dim (int): The dimension to do the operation over.
83
- exclusive (bool): If false, the behavior is the same as :func:`torch.cumsum`;
84
- if true, returns the cumulative sum exclusively. Note that if ture,
85
- the shape of output tensor is larger by 1 than :attr:`data` in the
86
- dimension where the computation occurs.
87
- '''
88
-
89
- out = torch.cumsum(data, dim)
90
-
91
- if exclusive:
92
- size = list(data.size())
93
- size[dim] = 1
94
- zeros = out.new_zeros(size)
95
- out = torch.cat([zeros, out], dim)
96
- return out
97
-
98
-
99
- def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
100
- r''' Broadcast :attr:`src` according to :attr:`other`, originally from the
101
- library `pytorch_scatter`.
102
- '''
103
-
104
- if dim < 0:
105
- dim = other.dim() + dim
106
-
107
- if src.dim() == 1:
108
- for _ in range(0, dim):
109
- src = src.unsqueeze(0)
110
- for _ in range(src.dim(), other.dim()):
111
- src = src.unsqueeze(-1)
112
-
113
- src = src.expand_as(other)
114
- return src
115
-
116
-
117
- def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
118
- out: Optional[torch.Tensor] = None,
119
- dim_size: Optional[int] = None,) -> torch.Tensor:
120
- r''' Reduces all values from the :attr:`src` tensor into :attr:`out` at the
121
- indices specified in the :attr:`index` tensor along a given axis :attr:`dim`.
122
- This is just a wrapper of :func:`torch.scatter` in a boardcasting fashion.
123
-
124
- Args:
125
- src (torch.Tensor): The source tensor.
126
- index (torch.Tensor): The indices of elements to scatter.
127
- dim (torch.Tensor): The axis along which to index, (default: :obj:`-1`).
128
- out (torch.Tensor or None): The destination tensor.
129
- dim_size (int or None): If :attr:`out` is not given, automatically create
130
- output with size :attr:`dim_size` at dimension :attr:`dim`. If
131
- :attr:`dim_size` is not given, a minimal sized output tensor according
132
- to :obj:`index.max() + 1` is returned.
133
- '''
134
-
135
- index = broadcast(index, src, dim)
136
-
137
- if out is None:
138
- size = list(src.size())
139
- if dim_size is not None:
140
- size[dim] = dim_size
141
- elif index.numel() == 0:
142
- size[dim] = 0
143
- else:
144
- size[dim] = int(index.max()) + 1
145
- out = torch.zeros(size, dtype=src.dtype, device=src.device)
146
-
147
- return out.scatter_add_(dim, index, src)
148
-
149
-
150
- def xavier_uniform_(weights: torch.Tensor):
151
- r''' Initialize convolution weights with the same method as
152
- :obj:`torch.nn.init.xavier_uniform_`.
153
-
154
- :obj:`torch.nn.init.xavier_uniform_` initialize a tensor with shape
155
- :obj:`(out_c, in_c, kdim)`, which can not be used in :class:`ocnn.nn.OctreeConv`
156
- since the the shape of :attr:`OctreeConv.weights` is :obj:`(kdim, in_c,
157
- out_c)`.
158
- '''
159
-
160
- shape = weights.shape # (kernel_dim, in_conv, out_conv)
161
- fan_in = shape[0] * shape[1]
162
- fan_out = shape[0] * shape[2]
163
- std = math.sqrt(2.0 / float(fan_in + fan_out))
164
- a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
165
-
166
- torch.nn.init.uniform_(weights, -a, a)
167
-
168
-
169
- def resize_with_last_val(list_in: list, num: int = 3):
170
- r''' Resizes the number of elements of :attr:`list_in` to :attr:`num` with
171
- the last element of :attr:`list_in` if its number of elements is smaller
172
- than :attr:`num`.
173
- '''
174
-
175
- assert (type(list_in) is list and len(list_in) < num + 1)
176
- for _ in range(len(list_in), num):
177
- list_in.append(list_in[-1])
178
- return list_in
179
-
180
-
181
- def list2str(list_in: list):
182
- r''' Returns a string representation of :attr:`list_in`.
183
- '''
184
-
185
- out = [str(x) for x in list_in]
186
- return ''.join(out)
187
-
188
-
189
- def build_example_octree(depth: int = 5, full_depth: int = 2, pt_num: int = 3):
190
- r''' Builds an example octree on CPU from at most 3 points.
191
- '''
192
- # initialize the point cloud
193
- points = torch.Tensor([[-1, -1, -1], [0, 0, -1], [0.0625, 0.0625, -1]])
194
- normals = torch.Tensor([[1, 0, 0], [-1, 0, 0], [0, 1, 0]])
195
- features = torch.Tensor([[1, -1], [2, -2], [3, -3]])
196
- labels = torch.Tensor([[0], [2], [2]])
197
-
198
- assert pt_num <= 3 and pt_num > 0
199
- point_cloud = ocnn.octree.Points(
200
- points[:pt_num], normals[:pt_num], features[:pt_num], labels[:pt_num])
201
-
202
- # build octree
203
- octree = ocnn.octree.Octree(depth, full_depth)
204
- octree.build_octree(point_cloud)
205
- return octree
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import math
9
+ import torch
10
+ from typing import Optional
11
+ from packaging import version
12
+
13
+ import ocnn
14
+
15
+
16
# Public API of this module; `classes` is kept as a legacy alias of `__all__`.
__all__ = ['trunc_div', 'meshgrid', 'cumsum', 'scatter_add', 'xavier_uniform_',
           'resize_with_last_val', 'list2str', 'build_example_octree']
classes = __all__
19
+
20
+
21
def trunc_div(input, other):
  r''' Wraps :func:`torch.div` for compatibility. It rounds the results of the
  division towards zero and is equivalent to C-style integer division.

  Args:
    input (torch.Tensor): The dividend.
    other (torch.Tensor or Number): The divisor.
  '''

  try:
    # `rounding_mode` exists since PyTorch 1.8; EAFP avoids parsing
    # `torch.__version__` (via third-party `packaging`) on every call.
    return torch.div(input, other, rounding_mode='trunc')
  except TypeError:
    # Older PyTorch: keep the legacy fallback of this wrapper.
    # NOTE(review): `floor_divide` rounds towards negative infinity, so
    # results differ from true truncation for negative quotients on old torch.
    return torch.floor_divide(input, other)
32
+
33
+
34
def meshgrid(*tensors, indexing: Optional[str] = None):
  r''' Wraps :func:`torch.meshgrid` for compatibility.

  Args:
    tensors (torch.Tensor): 1-D input tensors defining the grid axes.
    indexing (str, optional): Passed through to :func:`torch.meshgrid`
        ('ij' or 'xy') when the installed PyTorch supports it.
  '''

  try:
    # The `indexing` argument exists since PyTorch 1.10; EAFP avoids parsing
    # `torch.__version__` (via third-party `packaging`) on every call.
    return torch.meshgrid(*tensors, indexing=indexing)
  except TypeError:
    # Older PyTorch has no `indexing` argument and always behaved like 'ij'.
    return torch.meshgrid(*tensors)
44
+
45
+
46
def range_grid(min: int, max: int, device: torch.device = 'cpu'):
  r''' Builds a 3D mesh grid in :obj:`[min, max]` (:attr:`max` included).

  Args:
    min (int): The minimum value of the grid.
    max (int): The maximum value of the grid.
    device (torch.device, optional): The device to place the grid on.

  Returns:
    torch.Tensor: A 3D mesh grid tensor of shape (N, 3), where N is the total
        number of grid points, e.g. ``range_grid(0, 1)`` yields the 8 corners
        of the unit cube from ``[0, 0, 0]`` to ``[1, 1, 1]``.
  '''

  values = torch.arange(min, max + 1, dtype=torch.long, device=device)
  axes = meshgrid(values, values, values, indexing='ij')
  return torch.stack(axes, dim=-1).view(-1, 3)
75
+
76
+
77
def cumsum(data: torch.Tensor, dim: int, exclusive: bool = False):
  r''' Extends :func:`torch.cumsum` with the input argument :attr:`exclusive`.

  Args:
    data (torch.Tensor): The input data.
    dim (int): The dimension to do the operation over.
    exclusive (bool): If false, the behavior is the same as :func:`torch.cumsum`;
        if true, returns the cumulative sum exclusively. Note that if true,
        the shape of the output tensor is larger by 1 than :attr:`data` in the
        dimension where the computation occurs.
  '''

  summed = torch.cumsum(data, dim)
  if not exclusive:
    return summed

  # Prepend a zero slice so that entry i holds the sum of entries [0, i).
  pad_shape = list(data.size())
  pad_shape[dim] = 1
  return torch.cat([summed.new_zeros(pad_shape), summed], dim)
97
+
98
+
99
def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
  r''' Broadcast :attr:`src` according to :attr:`other`, originally from the
  library `pytorch_scatter`.

  Args:
    src (torch.Tensor): The tensor to broadcast.
    other (torch.Tensor): The tensor whose shape is the broadcast target.
    dim (int): The dimension of :attr:`other` that a 1-D :attr:`src` aligns with.
  '''

  dim = dim if dim >= 0 else other.dim() + dim

  if src.dim() == 1:
    # Reshape the vector so its single axis lands at position `dim`,
    # surrounded by singleton axes (equivalent to repeated unsqueeze calls).
    target = [1] * dim + [-1] + [1] * (other.dim() - dim - 1)
    src = src.reshape(target)
  return src.expand_as(other)
115
+
116
+
117
def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None,) -> torch.Tensor:
  r''' Reduces all values from the :attr:`src` tensor into :attr:`out` at the
  indices specified in the :attr:`index` tensor along a given axis :attr:`dim`.
  This is just a wrapper of :func:`torch.scatter` in a broadcasting fashion.

  Args:
    src (torch.Tensor): The source tensor.
    index (torch.Tensor): The indices of elements to scatter.
    dim (torch.Tensor): The axis along which to index, (default: :obj:`-1`).
    out (torch.Tensor or None): The destination tensor.
    dim_size (int or None): If :attr:`out` is not given, automatically create
        output with size :attr:`dim_size` at dimension :attr:`dim`. If
        :attr:`dim_size` is not given, a minimal sized output tensor according
        to :obj:`index.max() + 1` is returned.
  '''

  # Broadcast a 1-D `index` to the shape of `src` (inlined helper logic):
  # place its axis at `dim`, pad with singleton axes, then expand.
  axis = dim if dim >= 0 else src.dim() + dim
  if index.dim() == 1:
    for _ in range(axis):
      index = index.unsqueeze(0)
    for _ in range(index.dim(), src.dim()):
      index = index.unsqueeze(-1)
  index = index.expand_as(src)

  if out is None:
    shape = list(src.size())
    if dim_size is not None:
      shape[axis] = dim_size
    elif index.numel() == 0:
      shape[axis] = 0
    else:
      shape[axis] = int(index.max()) + 1
    out = torch.zeros(shape, dtype=src.dtype, device=src.device)

  return out.scatter_add_(axis, index, src)
148
+
149
+
150
def xavier_uniform_(weights: torch.Tensor):
  r''' Initialize convolution weights with the same method as
  :obj:`torch.nn.init.xavier_uniform_`.

  :obj:`torch.nn.init.xavier_uniform_` expects a tensor with shape
  :obj:`(out_c, in_c, kdim)`, which cannot be used on
  :attr:`OctreeConv.weights` directly since its shape is
  :obj:`(kdim, in_c, out_c)`.
  '''

  kdim, in_c, out_c = weights.shape
  fan_in = kdim * in_c
  fan_out = kdim * out_c
  # Glorot uniform bound: a = sqrt(3) * std, std = sqrt(2 / (fan_in + fan_out)).
  bound = math.sqrt(3.0) * math.sqrt(2.0 / float(fan_in + fan_out))
  torch.nn.init.uniform_(weights, -bound, bound)
167
+
168
+
169
def resize_with_last_val(list_in: list, num: int = 3):
  r''' Resizes the number of elements of :attr:`list_in` to :attr:`num` by
  repeating the last element of :attr:`list_in` when it has fewer than
  :attr:`num` elements. The list is modified in place and also returned.
  '''

  # `len(list_in) <= num` is equivalent to the legacy `len < num + 1` check.
  assert type(list_in) is list and len(list_in) <= num
  missing = num - len(list_in)
  if missing > 0:
    list_in.extend([list_in[-1]] * missing)
  return list_in
179
+
180
+
181
def list2str(list_in: list):
  r''' Returns the concatenation of the string forms of all elements of
  :attr:`list_in`, with no separator.
  '''

  return ''.join(str(item) for item in list_in)
187
+
188
+
189
def build_example_octree(depth: int = 5, full_depth: int = 2, pt_num: int = 3):
  r''' Builds an example octree on CPU from at most 3 points.

  Args:
    depth (int): The depth of the octree.
    full_depth (int): The full depth of the octree.
    pt_num (int): How many of the 3 fixed example points to use (1 to 3).
  '''

  assert 0 < pt_num <= 3
  # A tiny fixed point cloud: positions, normals, 2-channel features, labels.
  points = torch.Tensor([[-1, -1, -1], [0, 0, -1], [0.0625, 0.0625, -1]])
  normals = torch.Tensor([[1, 0, 0], [-1, 0, 0], [0, 1, 0]])
  features = torch.Tensor([[1, -1], [2, -2], [3, -3]])
  labels = torch.Tensor([[0], [2], [2]])
  point_cloud = ocnn.octree.Points(
      points[:pt_num], normals[:pt_num], features[:pt_num], labels[:pt_num])

  # Build the octree from the point cloud.
  octree = ocnn.octree.Octree(depth, full_depth)
  octree.build_octree(point_cloud)
  return octree