ocnn 2.2.2__tar.gz → 2.2.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ocnn-2.2.2/ocnn.egg-info → ocnn-2.2.4}/PKG-INFO +16 -3
- {ocnn-2.2.2 → ocnn-2.2.4}/README.md +15 -2
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/__init__.py +1 -1
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/models/autoencoder.py +0 -1
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/models/lenet.py +1 -1
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/models/ounet.py +3 -2
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/models/resnet.py +1 -1
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/models/segnet.py +2 -2
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/modules/__init__.py +9 -3
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/modules/modules.py +117 -7
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/modules/resblocks.py +39 -5
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/nn/__init__.py +5 -3
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/nn/octree_conv.py +18 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/nn/octree_dwconv.py +18 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/nn/octree_norm.py +46 -6
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/octree/octree.py +10 -9
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/octree/points.py +7 -2
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/utils.py +7 -4
- {ocnn-2.2.2 → ocnn-2.2.4/ocnn.egg-info}/PKG-INFO +16 -3
- {ocnn-2.2.2 → ocnn-2.2.4}/setup.py +1 -1
- {ocnn-2.2.2 → ocnn-2.2.4}/LICENSE +0 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/MANIFEST.in +0 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/dataset.py +0 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/models/__init__.py +0 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/models/hrnet.py +0 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/models/image2shape.py +0 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/models/unet.py +0 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/nn/octree2col.py +0 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/nn/octree2vox.py +0 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/nn/octree_align.py +0 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/nn/octree_drop.py +0 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/nn/octree_gconv.py +0 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/nn/octree_interp.py +0 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/nn/octree_pad.py +0 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/nn/octree_pool.py +0 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/octree/__init__.py +0 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn/octree/shuffled_key.py +0 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn.egg-info/SOURCES.txt +0 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn.egg-info/dependency_links.txt +0 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn.egg-info/not-zip-safe +0 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn.egg-info/requires.txt +0 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/ocnn.egg-info/top_level.txt +0 -0
- {ocnn-2.2.2 → ocnn-2.2.4}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ocnn
|
|
3
|
-
Version: 2.2.
|
|
3
|
+
Version: 2.2.4
|
|
4
4
|
Summary: Octree-based Sparse Convolutional Neural Networks
|
|
5
5
|
Home-page: https://github.com/octree-nn/ocnn-pytorch
|
|
6
6
|
Author: Peng-Shuai Wang
|
|
@@ -23,6 +23,7 @@ Requires-Dist: packaging
|
|
|
23
23
|
|
|
24
24
|
[](https://ocnn-pytorch.readthedocs.io/en/latest/?badge=latest)
|
|
25
25
|
[](https://pepy.tech/project/ocnn)
|
|
26
|
+
[](https://pepy.tech/project/ocnn)
|
|
26
27
|
[](https://pypi.org/project/ocnn/)
|
|
27
28
|
|
|
28
29
|
This repository contains the **pure PyTorch**-based implementation of
|
|
@@ -43,14 +44,14 @@ The key difference is that our O-CNN uses the `octree` to index the sparse
|
|
|
43
44
|
voxels, while these 3 works use the `Hash Table`.
|
|
44
45
|
|
|
45
46
|
Our O-CNN is published in SIGGRAPH 2017, H-CNN is published in TVCG 2018,
|
|
46
|
-
SparseConvNet is published in CVPR 2018, and MinkowskiNet is published in
|
|
47
|
+
SparseConvNet is published in CVPR 2018, and MinkowskiNet is published in
|
|
47
48
|
CVPR 2019. Actually, our O-CNN was submitted to SIGGRAPH in the end of 2016 and
|
|
48
49
|
was officially accepted in March, 2017. The camera-ready version of our O-CNN was
|
|
49
50
|
submitted to SIGGRAPH in April, 2017. We just did not post our paper on Arxiv
|
|
50
51
|
during the review process of SIGGRAPH. Therefore, **the idea of constraining CNN
|
|
51
52
|
computation into sparse non-emtpry voxels is first proposed by our O-CNN**.
|
|
52
53
|
Currently, this type of 3D convolution is known as Sparse Convolution in the
|
|
53
|
-
research community.
|
|
54
|
+
research community.
|
|
54
55
|
|
|
55
56
|
## Key benefits of ocnn-pytorch
|
|
56
57
|
|
|
@@ -65,3 +66,15 @@ research community.
|
|
|
65
66
|
training settings, MinkowskiNet 0.4.3 takes 60 hours and MinkowskiNet 0.5.4
|
|
66
67
|
takes 30 hours.
|
|
67
68
|
|
|
69
|
+
## Citation
|
|
70
|
+
|
|
71
|
+
```bibtex
|
|
72
|
+
@article {Wang-2017-ocnn,
|
|
73
|
+
title = {{O-CNN}: Octree-based Convolutional Neural Networksfor {3D} Shape Analysis},
|
|
74
|
+
author = {Wang, Peng-Shuai and Liu, Yang and Guo, Yu-Xiao and Sun, Chun-Yu and Tong, Xin},
|
|
75
|
+
journal = {ACM Transactions on Graphics (SIGGRAPH)},
|
|
76
|
+
volume = {36},
|
|
77
|
+
number = {4},
|
|
78
|
+
year = {2017},
|
|
79
|
+
}
|
|
80
|
+
```
|
|
@@ -4,6 +4,7 @@
|
|
|
4
4
|
|
|
5
5
|
[](https://ocnn-pytorch.readthedocs.io/en/latest/?badge=latest)
|
|
6
6
|
[](https://pepy.tech/project/ocnn)
|
|
7
|
+
[](https://pepy.tech/project/ocnn)
|
|
7
8
|
[](https://pypi.org/project/ocnn/)
|
|
8
9
|
|
|
9
10
|
This repository contains the **pure PyTorch**-based implementation of
|
|
@@ -24,14 +25,14 @@ The key difference is that our O-CNN uses the `octree` to index the sparse
|
|
|
24
25
|
voxels, while these 3 works use the `Hash Table`.
|
|
25
26
|
|
|
26
27
|
Our O-CNN is published in SIGGRAPH 2017, H-CNN is published in TVCG 2018,
|
|
27
|
-
SparseConvNet is published in CVPR 2018, and MinkowskiNet is published in
|
|
28
|
+
SparseConvNet is published in CVPR 2018, and MinkowskiNet is published in
|
|
28
29
|
CVPR 2019. Actually, our O-CNN was submitted to SIGGRAPH in the end of 2016 and
|
|
29
30
|
was officially accepted in March, 2017. The camera-ready version of our O-CNN was
|
|
30
31
|
submitted to SIGGRAPH in April, 2017. We just did not post our paper on Arxiv
|
|
31
32
|
during the review process of SIGGRAPH. Therefore, **the idea of constraining CNN
|
|
32
33
|
computation into sparse non-emtpry voxels is first proposed by our O-CNN**.
|
|
33
34
|
Currently, this type of 3D convolution is known as Sparse Convolution in the
|
|
34
|
-
research community.
|
|
35
|
+
research community.
|
|
35
36
|
|
|
36
37
|
## Key benefits of ocnn-pytorch
|
|
37
38
|
|
|
@@ -46,3 +47,15 @@ research community.
|
|
|
46
47
|
training settings, MinkowskiNet 0.4.3 takes 60 hours and MinkowskiNet 0.5.4
|
|
47
48
|
takes 30 hours.
|
|
48
49
|
|
|
50
|
+
## Citation
|
|
51
|
+
|
|
52
|
+
```bibtex
|
|
53
|
+
@article {Wang-2017-ocnn,
|
|
54
|
+
title = {{O-CNN}: Octree-based Convolutional Neural Networksfor {3D} Shape Analysis},
|
|
55
|
+
author = {Wang, Peng-Shuai and Liu, Yang and Guo, Yu-Xiao and Sun, Chun-Yu and Tong, Xin},
|
|
56
|
+
journal = {ACM Transactions on Graphics (SIGGRAPH)},
|
|
57
|
+
volume = {36},
|
|
58
|
+
number = {4},
|
|
59
|
+
year = {2017},
|
|
60
|
+
}
|
|
61
|
+
```
|
|
@@ -26,7 +26,7 @@ class LeNet(torch.nn.Module):
|
|
|
26
26
|
self.convs = torch.nn.ModuleList([ocnn.modules.OctreeConvBnRelu(
|
|
27
27
|
channels[i], channels[i+1], nempty=nempty) for i in range(stages)])
|
|
28
28
|
self.pools = torch.nn.ModuleList([ocnn.nn.OctreeMaxPool(
|
|
29
|
-
nempty) for
|
|
29
|
+
nempty) for _ in range(stages)])
|
|
30
30
|
self.octree2voxel = ocnn.nn.Octree2Voxel(self.nempty)
|
|
31
31
|
self.header = torch.nn.Sequential(
|
|
32
32
|
torch.nn.Dropout(p=0.5), # drop1
|
|
@@ -23,12 +23,13 @@ class OUNet(AutoEncoder):
|
|
|
23
23
|
self.proj = None # remove this module used in AutoEncoder
|
|
24
24
|
|
|
25
25
|
def encoder(self, octree):
|
|
26
|
-
r''' The encoder network for extracting heirarchy features.
|
|
26
|
+
r''' The encoder network for extracting heirarchy features.
|
|
27
27
|
'''
|
|
28
28
|
|
|
29
29
|
convs = dict()
|
|
30
30
|
depth, full_depth = self.depth, self.full_depth
|
|
31
|
-
data =
|
|
31
|
+
data = octree.get_input_feature(self.feature, nempty=False)
|
|
32
|
+
assert data.size(1) == self.channel_in
|
|
32
33
|
convs[depth] = self.conv1(data, octree, depth)
|
|
33
34
|
for i, d in enumerate(range(depth, full_depth-1, -1)):
|
|
34
35
|
convs[d] = self.encoder_blks[i](convs[d], octree, d)
|
|
@@ -31,7 +31,7 @@ class ResNet(torch.nn.Module):
|
|
|
31
31
|
channels[i], channels[i+1], resblock_num, nempty=nempty)
|
|
32
32
|
for i in range(stages-1)])
|
|
33
33
|
self.pools = torch.nn.ModuleList([ocnn.nn.OctreeMaxPool(
|
|
34
|
-
nempty) for
|
|
34
|
+
nempty) for _ in range(stages-1)])
|
|
35
35
|
self.global_pool = ocnn.nn.OctreeGlobalPool(nempty)
|
|
36
36
|
# self.header = torch.nn.Linear(channels[-1], out_channels, bias=True)
|
|
37
37
|
self.header = torch.nn.Sequential(
|
|
@@ -29,7 +29,7 @@ class SegNet(torch.nn.Module):
|
|
|
29
29
|
[ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
|
|
30
30
|
for i in range(stages)])
|
|
31
31
|
self.pools = torch.nn.ModuleList(
|
|
32
|
-
[ocnn.nn.OctreeMaxPool(nempty, return_indices) for
|
|
32
|
+
[ocnn.nn.OctreeMaxPool(nempty, return_indices) for _ in range(stages)])
|
|
33
33
|
|
|
34
34
|
self.bottleneck = ocnn.modules.OctreeConvBnRelu(channels[-1], channels[-1])
|
|
35
35
|
|
|
@@ -38,7 +38,7 @@ class SegNet(torch.nn.Module):
|
|
|
38
38
|
[ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
|
|
39
39
|
for i in range(0, stages)])
|
|
40
40
|
self.unpools = torch.nn.ModuleList(
|
|
41
|
-
[ocnn.nn.OctreeMaxUnpool(nempty) for
|
|
41
|
+
[ocnn.nn.OctreeMaxUnpool(nempty) for _ in range(stages)])
|
|
42
42
|
|
|
43
43
|
self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
|
|
44
44
|
self.header = torch.nn.Sequential(
|
|
@@ -7,14 +7,20 @@
|
|
|
7
7
|
|
|
8
8
|
from .modules import (InputFeature,
|
|
9
9
|
OctreeConvBn, OctreeConvBnRelu, OctreeDeconvBnRelu,
|
|
10
|
-
Conv1x1, Conv1x1Bn, Conv1x1BnRelu, FcBnRelu,
|
|
11
|
-
|
|
10
|
+
Conv1x1, Conv1x1Bn, Conv1x1BnRelu, FcBnRelu,
|
|
11
|
+
OctreeConvGn, OctreeConvGnRelu, OctreeDeconvGnRelu,
|
|
12
|
+
Conv1x1, Conv1x1Gn, Conv1x1GnRelu)
|
|
13
|
+
from .resblocks import (OctreeResBlock, OctreeResBlock2, OctreeResBlockGn,
|
|
14
|
+
OctreeResBlocks,)
|
|
12
15
|
|
|
13
16
|
__all__ = [
|
|
14
17
|
'InputFeature',
|
|
15
18
|
'OctreeConvBn', 'OctreeConvBnRelu', 'OctreeDeconvBnRelu',
|
|
16
19
|
'Conv1x1', 'Conv1x1Bn', 'Conv1x1BnRelu', 'FcBnRelu',
|
|
17
|
-
'
|
|
20
|
+
'OctreeConvGn', 'OctreeConvGnRelu', 'OctreeDeconvGnRelu',
|
|
21
|
+
'Conv1x1', 'Conv1x1Gn', 'Conv1x1GnRelu',
|
|
22
|
+
'OctreeResBlock', 'OctreeResBlock2', 'OctreeResBlockGn',
|
|
23
|
+
'OctreeResBlocks',
|
|
18
24
|
]
|
|
19
25
|
|
|
20
26
|
classes = __all__
|
|
@@ -9,7 +9,7 @@ import torch
|
|
|
9
9
|
import torch.utils.checkpoint
|
|
10
10
|
from typing import List
|
|
11
11
|
|
|
12
|
-
from ocnn.nn import OctreeConv, OctreeDeconv
|
|
12
|
+
from ocnn.nn import OctreeConv, OctreeDeconv, OctreeGroupNorm
|
|
13
13
|
from ocnn.octree import Octree
|
|
14
14
|
|
|
15
15
|
|
|
@@ -40,7 +40,7 @@ class OctreeConvBn(torch.nn.Module):
|
|
|
40
40
|
super().__init__()
|
|
41
41
|
self.conv = OctreeConv(
|
|
42
42
|
in_channels, out_channels, kernel_size, stride, nempty)
|
|
43
|
-
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
43
|
+
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
44
44
|
|
|
45
45
|
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
46
46
|
r''''''
|
|
@@ -62,7 +62,7 @@ class OctreeConvBnRelu(torch.nn.Module):
|
|
|
62
62
|
super().__init__()
|
|
63
63
|
self.conv = OctreeConv(
|
|
64
64
|
in_channels, out_channels, kernel_size, stride, nempty)
|
|
65
|
-
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
65
|
+
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
66
66
|
self.relu = torch.nn.ReLU(inplace=True)
|
|
67
67
|
|
|
68
68
|
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
@@ -86,7 +86,7 @@ class OctreeDeconvBnRelu(torch.nn.Module):
|
|
|
86
86
|
super().__init__()
|
|
87
87
|
self.deconv = OctreeDeconv(
|
|
88
88
|
in_channels, out_channels, kernel_size, stride, nempty)
|
|
89
|
-
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
89
|
+
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
90
90
|
self.relu = torch.nn.ReLU(inplace=True)
|
|
91
91
|
|
|
92
92
|
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
@@ -123,7 +123,7 @@ class Conv1x1Bn(torch.nn.Module):
|
|
|
123
123
|
def __init__(self, in_channels: int, out_channels: int):
|
|
124
124
|
super().__init__()
|
|
125
125
|
self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
|
|
126
|
-
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
126
|
+
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
127
127
|
|
|
128
128
|
def forward(self, data: torch.Tensor):
|
|
129
129
|
r''''''
|
|
@@ -140,7 +140,7 @@ class Conv1x1BnRelu(torch.nn.Module):
|
|
|
140
140
|
def __init__(self, in_channels: int, out_channels: int):
|
|
141
141
|
super().__init__()
|
|
142
142
|
self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
|
|
143
|
-
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
143
|
+
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
144
144
|
self.relu = torch.nn.ReLU(inplace=True)
|
|
145
145
|
|
|
146
146
|
def forward(self, data: torch.Tensor):
|
|
@@ -160,7 +160,7 @@ class FcBnRelu(torch.nn.Module):
|
|
|
160
160
|
super().__init__()
|
|
161
161
|
self.flatten = torch.nn.Flatten(start_dim=1)
|
|
162
162
|
self.fc = torch.nn.Linear(in_channels, out_channels, bias=False)
|
|
163
|
-
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
163
|
+
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
164
164
|
self.relu = torch.nn.ReLU(inplace=True)
|
|
165
165
|
|
|
166
166
|
def forward(self, data):
|
|
@@ -173,6 +173,116 @@ class FcBnRelu(torch.nn.Module):
|
|
|
173
173
|
return out
|
|
174
174
|
|
|
175
175
|
|
|
176
|
+
class OctreeConvGn(torch.nn.Module):
|
|
177
|
+
r''' A sequence of :class:`OctreeConv` and :obj:`OctreeGroupNorm`.
|
|
178
|
+
|
|
179
|
+
Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
|
|
180
|
+
'''
|
|
181
|
+
|
|
182
|
+
def __init__(self, in_channels: int, out_channels: int, group: int,
|
|
183
|
+
kernel_size: List[int] = [3], stride: int = 1,
|
|
184
|
+
nempty: bool = False):
|
|
185
|
+
super().__init__()
|
|
186
|
+
self.conv = OctreeConv(
|
|
187
|
+
in_channels, out_channels, kernel_size, stride, nempty)
|
|
188
|
+
self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
|
|
189
|
+
|
|
190
|
+
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
191
|
+
r''''''
|
|
192
|
+
|
|
193
|
+
out = self.conv(data, octree, depth)
|
|
194
|
+
out = self.gn(out, octree, depth)
|
|
195
|
+
return out
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class OctreeConvGnRelu(torch.nn.Module):
|
|
199
|
+
r''' A sequence of :class:`OctreeConv`, :obj:`OctreeGroupNorm`, and :obj:`Relu`.
|
|
200
|
+
|
|
201
|
+
Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
|
|
202
|
+
'''
|
|
203
|
+
|
|
204
|
+
def __init__(self, in_channels: int, out_channels: int, group: int,
|
|
205
|
+
kernel_size: List[int] = [3], stride: int = 1,
|
|
206
|
+
nempty: bool = False):
|
|
207
|
+
super().__init__()
|
|
208
|
+
self.stride = stride
|
|
209
|
+
self.conv = OctreeConv(
|
|
210
|
+
in_channels, out_channels, kernel_size, stride, nempty)
|
|
211
|
+
self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
|
|
212
|
+
self.relu = torch.nn.ReLU(inplace=True)
|
|
213
|
+
|
|
214
|
+
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
215
|
+
r''''''
|
|
216
|
+
|
|
217
|
+
out = self.conv(data, octree, depth)
|
|
218
|
+
out = self.gn(out, octree, depth if self.stride == 1 else depth - 1)
|
|
219
|
+
out = self.relu(out)
|
|
220
|
+
return out
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
class OctreeDeconvGnRelu(torch.nn.Module):
|
|
224
|
+
r''' A sequence of :class:`OctreeDeconv`, :obj:`OctreeGroupNorm`, and :obj:`Relu`.
|
|
225
|
+
|
|
226
|
+
Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
|
|
227
|
+
'''
|
|
228
|
+
|
|
229
|
+
def __init__(self, in_channels: int, out_channels: int, group: int,
|
|
230
|
+
kernel_size: List[int] = [3], stride: int = 1,
|
|
231
|
+
nempty: bool = False):
|
|
232
|
+
super().__init__()
|
|
233
|
+
self.stride = stride
|
|
234
|
+
self.deconv = OctreeDeconv(
|
|
235
|
+
in_channels, out_channels, kernel_size, stride, nempty)
|
|
236
|
+
self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
|
|
237
|
+
self.relu = torch.nn.ReLU(inplace=True)
|
|
238
|
+
|
|
239
|
+
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
240
|
+
r''''''
|
|
241
|
+
|
|
242
|
+
out = self.deconv(data, octree, depth)
|
|
243
|
+
out = self.gn(out, octree, depth if self.stride == 1 else depth + 1)
|
|
244
|
+
out = self.relu(out)
|
|
245
|
+
return out
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
class Conv1x1Gn(torch.nn.Module):
|
|
249
|
+
r''' A sequence of :class:`Conv1x1`, :class:`OctreeGroupNorm`.
|
|
250
|
+
'''
|
|
251
|
+
|
|
252
|
+
def __init__(self, in_channels: int, out_channels: int, group: int,
|
|
253
|
+
nempty: bool = False):
|
|
254
|
+
super().__init__()
|
|
255
|
+
self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
|
|
256
|
+
self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
|
|
257
|
+
|
|
258
|
+
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
259
|
+
r''''''
|
|
260
|
+
|
|
261
|
+
out = self.conv(data)
|
|
262
|
+
out = self.gn(out, octree, depth)
|
|
263
|
+
return out
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
class Conv1x1GnRelu(torch.nn.Module):
|
|
267
|
+
r''' A sequence of :class:`Conv1x1`, :class:`OctreeGroupNorm` and :class:`Relu`.
|
|
268
|
+
'''
|
|
269
|
+
|
|
270
|
+
def __init__(self, in_channels: int, out_channels: int, group: int,
|
|
271
|
+
nempty: bool = False):
|
|
272
|
+
super().__init__()
|
|
273
|
+
self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
|
|
274
|
+
self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
|
|
275
|
+
self.relu = torch.nn.ReLU(inplace=True)
|
|
276
|
+
|
|
277
|
+
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
278
|
+
r''''''
|
|
279
|
+
|
|
280
|
+
out = self.conv(data)
|
|
281
|
+
out = self.gn(out, octree, depth)
|
|
282
|
+
out = self.relu(out)
|
|
283
|
+
return out
|
|
284
|
+
|
|
285
|
+
|
|
176
286
|
class InputFeature(torch.nn.Module):
|
|
177
287
|
r''' Returns the initial input feature stored in octree.
|
|
178
288
|
|
|
@@ -10,7 +10,9 @@ import torch.utils.checkpoint
|
|
|
10
10
|
|
|
11
11
|
from ocnn.octree import Octree
|
|
12
12
|
from ocnn.nn import OctreeMaxPool
|
|
13
|
-
from ocnn.modules import Conv1x1BnRelu, OctreeConvBnRelu, Conv1x1Bn,
|
|
13
|
+
from ocnn.modules import (Conv1x1BnRelu, OctreeConvBnRelu, Conv1x1Bn,
|
|
14
|
+
OctreeConvBn, OctreeConvGnRelu, Conv1x1Gn,
|
|
15
|
+
OctreeConvGn,)
|
|
14
16
|
|
|
15
17
|
|
|
16
18
|
class OctreeResBlock(torch.nn.Module):
|
|
@@ -97,6 +99,38 @@ class OctreeResBlock2(torch.nn.Module):
|
|
|
97
99
|
return out
|
|
98
100
|
|
|
99
101
|
|
|
102
|
+
class OctreeResBlockGn(torch.nn.Module):
|
|
103
|
+
|
|
104
|
+
def __init__(self, in_channels: int, out_channels: int, stride: int = 1,
|
|
105
|
+
bottleneck: int = 4, nempty: bool = False, group: int = 32):
|
|
106
|
+
super().__init__()
|
|
107
|
+
self.in_channels = in_channels
|
|
108
|
+
self.out_channels = out_channels
|
|
109
|
+
self.stride = stride
|
|
110
|
+
channelb = int(out_channels / bottleneck)
|
|
111
|
+
|
|
112
|
+
if self.stride == 2:
|
|
113
|
+
self.maxpool = OctreeMaxPool(self.depth)
|
|
114
|
+
self.conv3x3a = OctreeConvGnRelu(in_channels, channelb, group, nempty=nempty)
|
|
115
|
+
self.conv3x3b = OctreeConvGn(channelb, out_channels, group, nempty=nempty)
|
|
116
|
+
if self.in_channels != self.out_channels:
|
|
117
|
+
self.conv1x1 = Conv1x1Gn(in_channels, out_channels, group)
|
|
118
|
+
self.relu = torch.nn.ReLU(inplace=True)
|
|
119
|
+
|
|
120
|
+
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
121
|
+
r''''''
|
|
122
|
+
|
|
123
|
+
if self.stride == 2:
|
|
124
|
+
data = self.maxpool(data, octree, depth)
|
|
125
|
+
depth = depth - 1
|
|
126
|
+
conv1 = self.conv3x3a(data, octree, depth)
|
|
127
|
+
conv2 = self.conv3x3b(conv1, octree, depth)
|
|
128
|
+
if self.in_channels != self.out_channels:
|
|
129
|
+
data = self.conv1x1(data, octree, depth)
|
|
130
|
+
out = self.relu(conv2 + data)
|
|
131
|
+
return out
|
|
132
|
+
|
|
133
|
+
|
|
100
134
|
class OctreeResBlocks(torch.nn.Module):
|
|
101
135
|
r''' A sequence of :attr:`resblk_num` ResNet blocks.
|
|
102
136
|
'''
|
|
@@ -108,9 +142,9 @@ class OctreeResBlocks(torch.nn.Module):
|
|
|
108
142
|
self.use_checkpoint = use_checkpoint
|
|
109
143
|
channels = [in_channels] + [out_channels] * resblk_num
|
|
110
144
|
|
|
111
|
-
self.resblks = torch.nn.ModuleList(
|
|
112
|
-
|
|
113
|
-
|
|
145
|
+
self.resblks = torch.nn.ModuleList([resblk(
|
|
146
|
+
channels[i], channels[i+1], 1, bottleneck, nempty)
|
|
147
|
+
for i in range(self.resblk_num)])
|
|
114
148
|
|
|
115
149
|
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
116
150
|
r''''''
|
|
@@ -118,7 +152,7 @@ class OctreeResBlocks(torch.nn.Module):
|
|
|
118
152
|
for i in range(self.resblk_num):
|
|
119
153
|
if self.use_checkpoint:
|
|
120
154
|
data = torch.utils.checkpoint.checkpoint(
|
|
121
|
-
self.resblks[i], data, octree, depth)
|
|
155
|
+
self.resblks[i], data, octree, depth, use_reentrant=False)
|
|
122
156
|
else:
|
|
123
157
|
data = self.resblks[i](data, octree, depth)
|
|
124
158
|
return data
|
|
@@ -15,8 +15,10 @@ from .octree_pool import (octree_max_pool, OctreeMaxPool,
|
|
|
15
15
|
octree_global_pool, OctreeGlobalPool,
|
|
16
16
|
octree_avg_pool, OctreeAvgPool,)
|
|
17
17
|
from .octree_conv import OctreeConv, OctreeDeconv
|
|
18
|
+
from .octree_gconv import OctreeGroupConv
|
|
18
19
|
from .octree_dwconv import OctreeDWConv
|
|
19
|
-
from .octree_norm import OctreeBatchNorm, OctreeGroupNorm,
|
|
20
|
+
from .octree_norm import (OctreeBatchNorm, OctreeGroupNorm,
|
|
21
|
+
OctreeInstanceNorm, OctreeNorm)
|
|
20
22
|
from .octree_drop import OctreeDropPath
|
|
21
23
|
from .octree_align import search_value, octree_align
|
|
22
24
|
|
|
@@ -32,9 +34,9 @@ __all__ = [
|
|
|
32
34
|
'OctreeMaxPool', 'OctreeMaxUnpool',
|
|
33
35
|
'OctreeGlobalPool', 'OctreeAvgPool',
|
|
34
36
|
'OctreeConv', 'OctreeDeconv',
|
|
35
|
-
'OctreeDWConv',
|
|
37
|
+
'OctreeGroupConv', 'OctreeDWConv',
|
|
36
38
|
'OctreeInterp', 'OctreeUpsample',
|
|
37
|
-
'OctreeInstanceNorm', 'OctreeBatchNorm', 'OctreeGroupNorm',
|
|
39
|
+
'OctreeInstanceNorm', 'OctreeBatchNorm', 'OctreeGroupNorm', 'OctreeNorm',
|
|
38
40
|
'OctreeDropPath',
|
|
39
41
|
'search_value', 'octree_align',
|
|
40
42
|
]
|
|
@@ -109,6 +109,12 @@ class OctreeConvBase:
|
|
|
109
109
|
r''' Peforms the forward pass of octree-based convolution.
|
|
110
110
|
'''
|
|
111
111
|
|
|
112
|
+
# Type check
|
|
113
|
+
if data.dtype != out.dtype:
|
|
114
|
+
data = data.to(out.dtype)
|
|
115
|
+
if weights.dtype != out.dtype:
|
|
116
|
+
weights = weights.to(out.dtype)
|
|
117
|
+
|
|
112
118
|
# Initialize the buffer
|
|
113
119
|
buffer = data.new_empty(self.buffer_shape)
|
|
114
120
|
|
|
@@ -139,6 +145,12 @@ class OctreeConvBase:
|
|
|
139
145
|
r''' Performs the backward pass of octree-based convolution.
|
|
140
146
|
'''
|
|
141
147
|
|
|
148
|
+
# Type check
|
|
149
|
+
if grad.dtype != out.dtype:
|
|
150
|
+
grad = grad.to(out.dtype)
|
|
151
|
+
if weights.dtype != out.dtype:
|
|
152
|
+
weights = weights.to(out.dtype)
|
|
153
|
+
|
|
142
154
|
# Loop over each sub-matrix
|
|
143
155
|
for i in range(self.buffer_n):
|
|
144
156
|
start = i * self.buffer_h
|
|
@@ -165,6 +177,12 @@ class OctreeConvBase:
|
|
|
165
177
|
r''' Computes the gradient of the weight matrix.
|
|
166
178
|
'''
|
|
167
179
|
|
|
180
|
+
# Type check
|
|
181
|
+
if data.dtype != out.dtype:
|
|
182
|
+
data = data.to(out.dtype)
|
|
183
|
+
if grad.dtype != out.dtype:
|
|
184
|
+
grad = grad.to(out.dtype)
|
|
185
|
+
|
|
168
186
|
# Record the shape of out
|
|
169
187
|
out_shape = out.shape
|
|
170
188
|
out = out.flatten(0, 1)
|
|
@@ -32,6 +32,12 @@ class OctreeDWConvBase(OctreeConvBase):
|
|
|
32
32
|
r''' Peforms the forward pass of octree-based convolution.
|
|
33
33
|
'''
|
|
34
34
|
|
|
35
|
+
# Type check
|
|
36
|
+
if data.dtype != out.dtype:
|
|
37
|
+
data = data.to(out.dtype)
|
|
38
|
+
if weights.dtype != out.dtype:
|
|
39
|
+
weights = weights.to(out.dtype)
|
|
40
|
+
|
|
35
41
|
# Initialize the buffer
|
|
36
42
|
buffer = data.new_empty(self.buffer_shape)
|
|
37
43
|
|
|
@@ -62,6 +68,12 @@ class OctreeDWConvBase(OctreeConvBase):
|
|
|
62
68
|
r''' Performs the backward pass of octree-based convolution.
|
|
63
69
|
'''
|
|
64
70
|
|
|
71
|
+
# Type check
|
|
72
|
+
if grad.dtype != out.dtype:
|
|
73
|
+
grad = grad.to(out.dtype)
|
|
74
|
+
if weights.dtype != out.dtype:
|
|
75
|
+
weights = weights.to(out.dtype)
|
|
76
|
+
|
|
65
77
|
# Loop over each sub-matrix
|
|
66
78
|
for i in range(self.buffer_n):
|
|
67
79
|
start = i * self.buffer_h
|
|
@@ -88,6 +100,12 @@ class OctreeDWConvBase(OctreeConvBase):
|
|
|
88
100
|
r''' Computes the gradient of the weight matrix.
|
|
89
101
|
'''
|
|
90
102
|
|
|
103
|
+
# Type check
|
|
104
|
+
if data.dtype != out.dtype:
|
|
105
|
+
data = data.to(out.dtype)
|
|
106
|
+
if grad.dtype != out.dtype:
|
|
107
|
+
grad = grad.to(out.dtype)
|
|
108
|
+
|
|
91
109
|
# Record the shape of out
|
|
92
110
|
out_shape = out.shape
|
|
93
111
|
out = out.flatten(0, 1)
|
|
@@ -7,6 +7,7 @@
|
|
|
7
7
|
|
|
8
8
|
import torch
|
|
9
9
|
import torch.nn
|
|
10
|
+
from typing import Optional
|
|
10
11
|
|
|
11
12
|
from ocnn.octree import Octree
|
|
12
13
|
from ocnn.utils import scatter_add
|
|
@@ -19,15 +20,19 @@ class OctreeGroupNorm(torch.nn.Module):
|
|
|
19
20
|
r''' An group normalization layer for the octree.
|
|
20
21
|
'''
|
|
21
22
|
|
|
22
|
-
def __init__(self, in_channels: int, group: int, nempty: bool = False
|
|
23
|
+
def __init__(self, in_channels: int, group: int, nempty: bool = False,
|
|
24
|
+
min_group_channels: int = 4):
|
|
23
25
|
super().__init__()
|
|
24
26
|
self.eps = 1e-5
|
|
25
27
|
self.nempty = nempty
|
|
26
28
|
self.group = group
|
|
27
29
|
self.in_channels = in_channels
|
|
30
|
+
self.min_group_channels = min_group_channels
|
|
31
|
+
if self.min_group_channels * self.group > in_channels:
|
|
32
|
+
self.group = in_channels // self.min_group_channels
|
|
28
33
|
|
|
29
|
-
assert in_channels % group == 0
|
|
30
|
-
self.channels_per_group = in_channels // group
|
|
34
|
+
assert in_channels % self.group == 0
|
|
35
|
+
self.channels_per_group = in_channels // self.group
|
|
31
36
|
|
|
32
37
|
self.weights = torch.nn.Parameter(torch.Tensor(1, in_channels))
|
|
33
38
|
self.bias = torch.nn.Parameter(torch.Tensor(1, in_channels))
|
|
@@ -71,8 +76,8 @@ class OctreeGroupNorm(torch.nn.Module):
|
|
|
71
76
|
return tensor
|
|
72
77
|
|
|
73
78
|
def extra_repr(self) -> str:
|
|
74
|
-
return ('in_channels={}, group={}, nempty={}').format(
|
|
75
|
-
|
|
79
|
+
return ('in_channels={}, group={}, nempty={}, min_group_channels={}').format(
|
|
80
|
+
self.in_channels, self.group, self.nempty, self.min_group_channels)
|
|
76
81
|
|
|
77
82
|
|
|
78
83
|
class OctreeInstanceNorm(OctreeGroupNorm):
|
|
@@ -80,7 +85,42 @@ class OctreeInstanceNorm(OctreeGroupNorm):
|
|
|
80
85
|
'''
|
|
81
86
|
|
|
82
87
|
def __init__(self, in_channels: int, nempty: bool = False):
|
|
83
|
-
super().__init__(in_channels=in_channels, group=in_channels, nempty=nempty
|
|
88
|
+
super().__init__(in_channels=in_channels, group=in_channels, nempty=nempty,
|
|
89
|
+
min_group_channels=1) # NOTE: group=in_channels
|
|
84
90
|
|
|
85
91
|
def extra_repr(self) -> str:
|
|
86
92
|
return ('in_channels={}, nempty={}').format(self.in_channels, self.nempty)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class OctreeNorm(torch.nn.Module):
|
|
96
|
+
r''' A normalization layer for the octree. It encapsulates octree-based batch,
|
|
97
|
+
group and instance normalization.
|
|
98
|
+
'''
|
|
99
|
+
|
|
100
|
+
def __init__(self, in_channels: int, norm_type: str = 'batch_norm',
|
|
101
|
+
group: int = 32, min_group_channels: int = 4):
|
|
102
|
+
super().__init__()
|
|
103
|
+
self.in_channels = in_channels
|
|
104
|
+
self.norm_type = norm_type
|
|
105
|
+
self.group = group
|
|
106
|
+
self.min_group_channels = min_group_channels
|
|
107
|
+
|
|
108
|
+
if self.norm_type == 'batch_norm':
|
|
109
|
+
self.norm = torch.nn.BatchNorm1d(in_channels)
|
|
110
|
+
elif self.norm_type == 'group_norm':
|
|
111
|
+
self.norm = OctreeGroupNorm(in_channels, group, min_group_channels)
|
|
112
|
+
elif self.norm_type == 'instance_norm':
|
|
113
|
+
self.norm = OctreeInstanceNorm(in_channels)
|
|
114
|
+
else:
|
|
115
|
+
raise ValueError
|
|
116
|
+
|
|
117
|
+
def forward(self, x: torch.Tensor, octree: Optional[Octree] = None,
|
|
118
|
+
depth: Optional[int] = None):
|
|
119
|
+
if self.norm_type == 'batch_norm':
|
|
120
|
+
output = self.norm(x)
|
|
121
|
+
elif (self.norm_type == 'group_norm' or
|
|
122
|
+
self.norm_type == 'instance_norm'):
|
|
123
|
+
output = self.norm(x, octree, depth)
|
|
124
|
+
else:
|
|
125
|
+
raise ValueError
|
|
126
|
+
return output
|
|
@@ -63,14 +63,14 @@ class Octree:
|
|
|
63
63
|
|
|
64
64
|
# octree node numbers in each octree layers.
|
|
65
65
|
# TODO: decide whether to settle them to 'gpu' or not?
|
|
66
|
-
self.nnum = torch.zeros(num, dtype=torch.
|
|
67
|
-
self.nnum_nempty = torch.zeros(num, dtype=torch.
|
|
66
|
+
self.nnum = torch.zeros(num, dtype=torch.long)
|
|
67
|
+
self.nnum_nempty = torch.zeros(num, dtype=torch.long)
|
|
68
68
|
|
|
69
69
|
# the following properties are valid after `merge_octrees`.
|
|
70
70
|
# TODO: make them valid after `octree_grow`, `octree_split` and `build_octree`
|
|
71
71
|
batch_size = self.batch_size
|
|
72
|
-
self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.
|
|
73
|
-
self.batch_nnum_nempty = torch.zeros(num, batch_size, dtype=torch.
|
|
72
|
+
self.batch_nnum = torch.zeros(num, batch_size, dtype=torch.long)
|
|
73
|
+
self.batch_nnum_nempty = torch.zeros(num, batch_size, dtype=torch.long)
|
|
74
74
|
|
|
75
75
|
# construct the look up tables for neighborhood searching
|
|
76
76
|
device = self.device
|
|
@@ -274,7 +274,7 @@ class Octree:
|
|
|
274
274
|
children[0] = 0
|
|
275
275
|
|
|
276
276
|
# update octree
|
|
277
|
-
self.children[depth] = children
|
|
277
|
+
self.children[depth] = children.int()
|
|
278
278
|
self.nnum_nempty[depth] = nnum_nempty
|
|
279
279
|
|
|
280
280
|
def octree_grow(self, depth: int, update_neigh: bool = True):
|
|
@@ -290,7 +290,7 @@ class Octree:
|
|
|
290
290
|
# node number
|
|
291
291
|
nnum = self.nnum_nempty[depth-1] * 8
|
|
292
292
|
self.nnum[depth] = nnum
|
|
293
|
-
self.nnum_nempty[depth] = nnum
|
|
293
|
+
self.nnum_nempty[depth] = nnum # initialize self.nnum_nempty
|
|
294
294
|
|
|
295
295
|
# update keys
|
|
296
296
|
key = self.key(depth-1, nempty=True)
|
|
@@ -326,7 +326,7 @@ class Octree:
|
|
|
326
326
|
xyz = xyz.view(-1, 3) # (N*27, 3)
|
|
327
327
|
neigh = xyz2key(xyz[:, 0], xyz[:, 1], xyz[:, 2], depth=depth)
|
|
328
328
|
|
|
329
|
-
bs = torch.arange(self.batch_size, dtype=torch.
|
|
329
|
+
bs = torch.arange(self.batch_size, dtype=torch.long, device=device)
|
|
330
330
|
neigh = neigh + bs.unsqueeze(1) * nnum # (N*27,) + (B, 1) -> (B, N*27)
|
|
331
331
|
|
|
332
332
|
bound = 1 << depth
|
|
@@ -383,9 +383,10 @@ class Octree:
|
|
|
383
383
|
# I choose `torch.bucketize` here because it has fewer dimension checks,
|
|
384
384
|
# resulting in slightly better performance according to the docs of
|
|
385
385
|
# pytorch-1.9.1, since `key` is always 1-D sorted sequence.
|
|
386
|
+
# https://pytorch.org/docs/1.9.1/generated/torch.searchsorted.html
|
|
386
387
|
idx = torch.bucketize(query, key)
|
|
387
388
|
|
|
388
|
-
valid = idx < key.shape[0] # valid if
|
|
389
|
+
valid = idx < key.shape[0] # valid if in-bound
|
|
389
390
|
found = key[idx[valid]] == query[valid]
|
|
390
391
|
valid[valid.clone()] = found # valid if found
|
|
391
392
|
idx[valid.logical_not()] = -1 # set to -1 if invalid
|
|
@@ -498,7 +499,7 @@ class Octree:
|
|
|
498
499
|
# normalize xyz to [-1, 1] since the average points are in range [0, 2^d]
|
|
499
500
|
if rescale:
|
|
500
501
|
scale = 2 ** (1 - depth)
|
|
501
|
-
xyz =
|
|
502
|
+
xyz = xyz * scale - 1.0
|
|
502
503
|
|
|
503
504
|
# construct Points
|
|
504
505
|
out = Points(xyz, self.normals[depth], self.features[depth],
|
|
@@ -56,11 +56,16 @@ class Points:
|
|
|
56
56
|
assert self.features.dim() == 2
|
|
57
57
|
assert self.features.size(0) == self.points.size(0)
|
|
58
58
|
if self.labels is not None:
|
|
59
|
-
assert self.labels.dim() == 2
|
|
59
|
+
assert self.labels.dim() == 2 or self.labels.dim() == 1
|
|
60
60
|
assert self.labels.size(0) == self.points.size(0)
|
|
61
|
+
if self.labels.dim() == 1:
|
|
62
|
+
self.labels = self.labels.unsqueeze(1)
|
|
61
63
|
if self.batch_id is not None:
|
|
62
|
-
assert self.batch_id.dim() == 2
|
|
64
|
+
assert self.batch_id.dim() == 2 or self.batch_id.dim() == 1
|
|
63
65
|
assert self.batch_id.size(0) == self.points.size(0)
|
|
66
|
+
assert self.batch_id.size(1) == 1
|
|
67
|
+
if self.batch_id.dim() == 1:
|
|
68
|
+
self.batch_id = self.batch_id.unsqueeze(1)
|
|
64
69
|
|
|
65
70
|
@property
|
|
66
71
|
def npt(self):
|
|
@@ -173,7 +173,7 @@ def resize_with_last_val(list_in: list, num: int = 3):
|
|
|
173
173
|
'''
|
|
174
174
|
|
|
175
175
|
assert (type(list_in) is list and len(list_in) < num + 1)
|
|
176
|
-
for
|
|
176
|
+
for _ in range(len(list_in), num):
|
|
177
177
|
list_in.append(list_in[-1])
|
|
178
178
|
return list_in
|
|
179
179
|
|
|
@@ -186,15 +186,18 @@ def list2str(list_in: list):
|
|
|
186
186
|
return ''.join(out)
|
|
187
187
|
|
|
188
188
|
|
|
189
|
-
def build_example_octree(depth: int = 5, full_depth: int = 2):
|
|
190
|
-
r''' Builds an example octree on CPU from 3 points.
|
|
189
|
+
def build_example_octree(depth: int = 5, full_depth: int = 2, pt_num: int = 3):
|
|
190
|
+
r''' Builds an example octree on CPU from at most 3 points.
|
|
191
191
|
'''
|
|
192
192
|
# initialize the point cloud
|
|
193
193
|
points = torch.Tensor([[-1, -1, -1], [0, 0, -1], [0.0625, 0.0625, -1]])
|
|
194
194
|
normals = torch.Tensor([[1, 0, 0], [-1, 0, 0], [0, 1, 0]])
|
|
195
195
|
features = torch.Tensor([[1, -1], [2, -2], [3, -3]])
|
|
196
196
|
labels = torch.Tensor([[0], [2], [2]])
|
|
197
|
-
|
|
197
|
+
|
|
198
|
+
assert pt_num <= 3 and pt_num > 0
|
|
199
|
+
point_cloud = ocnn.octree.Points(
|
|
200
|
+
points[:pt_num], normals[:pt_num], features[:pt_num], labels[:pt_num])
|
|
198
201
|
|
|
199
202
|
# build octree
|
|
200
203
|
octree = ocnn.octree.Octree(depth, full_depth)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ocnn
|
|
3
|
-
Version: 2.2.
|
|
3
|
+
Version: 2.2.4
|
|
4
4
|
Summary: Octree-based Sparse Convolutional Neural Networks
|
|
5
5
|
Home-page: https://github.com/octree-nn/ocnn-pytorch
|
|
6
6
|
Author: Peng-Shuai Wang
|
|
@@ -23,6 +23,7 @@ Requires-Dist: packaging
|
|
|
23
23
|
|
|
24
24
|
[](https://ocnn-pytorch.readthedocs.io/en/latest/?badge=latest)
|
|
25
25
|
[](https://pepy.tech/project/ocnn)
|
|
26
|
+
[](https://pepy.tech/project/ocnn)
|
|
26
27
|
[](https://pypi.org/project/ocnn/)
|
|
27
28
|
|
|
28
29
|
This repository contains the **pure PyTorch**-based implementation of
|
|
@@ -43,14 +44,14 @@ The key difference is that our O-CNN uses the `octree` to index the sparse
|
|
|
43
44
|
voxels, while these 3 works use the `Hash Table`.
|
|
44
45
|
|
|
45
46
|
Our O-CNN is published in SIGGRAPH 2017, H-CNN is published in TVCG 2018,
|
|
46
|
-
SparseConvNet is published in CVPR 2018, and MinkowskiNet is published in
|
|
47
|
+
SparseConvNet is published in CVPR 2018, and MinkowskiNet is published in
|
|
47
48
|
CVPR 2019. Actually, our O-CNN was submitted to SIGGRAPH at the end of 2016 and
|
|
48
49
|
was officially accepted in March, 2017. The camera-ready version of our O-CNN was
|
|
49
50
|
submitted to SIGGRAPH in April, 2017. We simply did not post our paper on arXiv
|
|
50
51
|
during the review process of SIGGRAPH. Therefore, **the idea of constraining CNN
|
|
51
52
|
computation into sparse non-empty voxels was first proposed by our O-CNN**.
|
|
52
53
|
Currently, this type of 3D convolution is known as Sparse Convolution in the
|
|
53
|
-
research community.
|
|
54
|
+
research community.
|
|
54
55
|
|
|
55
56
|
## Key benefits of ocnn-pytorch
|
|
56
57
|
|
|
@@ -65,3 +66,15 @@ research community.
|
|
|
65
66
|
training settings, MinkowskiNet 0.4.3 takes 60 hours and MinkowskiNet 0.5.4
|
|
66
67
|
takes 30 hours.
|
|
67
68
|
|
|
69
|
+
## Citation
|
|
70
|
+
|
|
71
|
+
```bibtex
|
|
72
|
+
@article {Wang-2017-ocnn,
|
|
73
|
+
title = {{O-CNN}: Octree-based Convolutional Neural Networks for {3D} Shape Analysis},
|
|
74
|
+
author = {Wang, Peng-Shuai and Liu, Yang and Guo, Yu-Xiao and Sun, Chun-Yu and Tong, Xin},
|
|
75
|
+
journal = {ACM Transactions on Graphics (SIGGRAPH)},
|
|
76
|
+
volume = {36},
|
|
77
|
+
number = {4},
|
|
78
|
+
year = {2017},
|
|
79
|
+
}
|
|
80
|
+
```
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|