ocnn 2.2.2__py3-none-any.whl → 2.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocnn/__init__.py +1 -1
- ocnn/models/lenet.py +1 -1
- ocnn/nn/__init__.py +2 -1
- ocnn/nn/octree_conv.py +18 -0
- ocnn/nn/octree_dwconv.py +18 -0
- ocnn/octree/octree.py +2 -2
- ocnn/octree/points.py +7 -2
- ocnn/utils.py +7 -4
- {ocnn-2.2.2.dist-info → ocnn-2.2.3.dist-info}/METADATA +15 -3
- {ocnn-2.2.2.dist-info → ocnn-2.2.3.dist-info}/RECORD +13 -13
- {ocnn-2.2.2.dist-info → ocnn-2.2.3.dist-info}/WHEEL +1 -1
- {ocnn-2.2.2.dist-info → ocnn-2.2.3.dist-info}/LICENSE +0 -0
- {ocnn-2.2.2.dist-info → ocnn-2.2.3.dist-info}/top_level.txt +0 -0
ocnn/__init__.py
CHANGED
ocnn/models/lenet.py
CHANGED
|
@@ -26,7 +26,7 @@ class LeNet(torch.nn.Module):
|
|
|
26
26
|
self.convs = torch.nn.ModuleList([ocnn.modules.OctreeConvBnRelu(
|
|
27
27
|
channels[i], channels[i+1], nempty=nempty) for i in range(stages)])
|
|
28
28
|
self.pools = torch.nn.ModuleList([ocnn.nn.OctreeMaxPool(
|
|
29
|
-
nempty) for
|
|
29
|
+
nempty) for _ in range(stages)])
|
|
30
30
|
self.octree2voxel = ocnn.nn.Octree2Voxel(self.nempty)
|
|
31
31
|
self.header = torch.nn.Sequential(
|
|
32
32
|
torch.nn.Dropout(p=0.5), # drop1
|
ocnn/nn/__init__.py
CHANGED
|
@@ -15,6 +15,7 @@ from .octree_pool import (octree_max_pool, OctreeMaxPool,
|
|
|
15
15
|
octree_global_pool, OctreeGlobalPool,
|
|
16
16
|
octree_avg_pool, OctreeAvgPool,)
|
|
17
17
|
from .octree_conv import OctreeConv, OctreeDeconv
|
|
18
|
+
from .octree_gconv import OctreeGroupConv
|
|
18
19
|
from .octree_dwconv import OctreeDWConv
|
|
19
20
|
from .octree_norm import OctreeBatchNorm, OctreeGroupNorm, OctreeInstanceNorm
|
|
20
21
|
from .octree_drop import OctreeDropPath
|
|
@@ -32,7 +33,7 @@ __all__ = [
|
|
|
32
33
|
'OctreeMaxPool', 'OctreeMaxUnpool',
|
|
33
34
|
'OctreeGlobalPool', 'OctreeAvgPool',
|
|
34
35
|
'OctreeConv', 'OctreeDeconv',
|
|
35
|
-
'OctreeDWConv',
|
|
36
|
+
'OctreeGroupConv', 'OctreeDWConv',
|
|
36
37
|
'OctreeInterp', 'OctreeUpsample',
|
|
37
38
|
'OctreeInstanceNorm', 'OctreeBatchNorm', 'OctreeGroupNorm',
|
|
38
39
|
'OctreeDropPath',
|
ocnn/nn/octree_conv.py
CHANGED
|
@@ -109,6 +109,12 @@ class OctreeConvBase:
|
|
|
109
109
|
r''' Performs the forward pass of octree-based convolution.
|
|
110
110
|
'''
|
|
111
111
|
|
|
112
|
+
# Type check
|
|
113
|
+
if data.dtype != out.dtype:
|
|
114
|
+
data = data.to(out.dtype)
|
|
115
|
+
if weights.dtype != out.dtype:
|
|
116
|
+
weights = weights.to(out.dtype)
|
|
117
|
+
|
|
112
118
|
# Initialize the buffer
|
|
113
119
|
buffer = data.new_empty(self.buffer_shape)
|
|
114
120
|
|
|
@@ -139,6 +145,12 @@ class OctreeConvBase:
|
|
|
139
145
|
r''' Performs the backward pass of octree-based convolution.
|
|
140
146
|
'''
|
|
141
147
|
|
|
148
|
+
# Type check
|
|
149
|
+
if grad.dtype != out.dtype:
|
|
150
|
+
grad = grad.to(out.dtype)
|
|
151
|
+
if weights.dtype != out.dtype:
|
|
152
|
+
weights = weights.to(out.dtype)
|
|
153
|
+
|
|
142
154
|
# Loop over each sub-matrix
|
|
143
155
|
for i in range(self.buffer_n):
|
|
144
156
|
start = i * self.buffer_h
|
|
@@ -165,6 +177,12 @@ class OctreeConvBase:
|
|
|
165
177
|
r''' Computes the gradient of the weight matrix.
|
|
166
178
|
'''
|
|
167
179
|
|
|
180
|
+
# Type check
|
|
181
|
+
if data.dtype != out.dtype:
|
|
182
|
+
data = data.to(out.dtype)
|
|
183
|
+
if grad.dtype != out.dtype:
|
|
184
|
+
grad = grad.to(out.dtype)
|
|
185
|
+
|
|
168
186
|
# Record the shape of out
|
|
169
187
|
out_shape = out.shape
|
|
170
188
|
out = out.flatten(0, 1)
|
ocnn/nn/octree_dwconv.py
CHANGED
|
@@ -32,6 +32,12 @@ class OctreeDWConvBase(OctreeConvBase):
|
|
|
32
32
|
r''' Performs the forward pass of octree-based convolution.
|
|
33
33
|
'''
|
|
34
34
|
|
|
35
|
+
# Type check
|
|
36
|
+
if data.dtype != out.dtype:
|
|
37
|
+
data = data.to(out.dtype)
|
|
38
|
+
if weights.dtype != out.dtype:
|
|
39
|
+
weights = weights.to(out.dtype)
|
|
40
|
+
|
|
35
41
|
# Initialize the buffer
|
|
36
42
|
buffer = data.new_empty(self.buffer_shape)
|
|
37
43
|
|
|
@@ -62,6 +68,12 @@ class OctreeDWConvBase(OctreeConvBase):
|
|
|
62
68
|
r''' Performs the backward pass of octree-based convolution.
|
|
63
69
|
'''
|
|
64
70
|
|
|
71
|
+
# Type check
|
|
72
|
+
if grad.dtype != out.dtype:
|
|
73
|
+
grad = grad.to(out.dtype)
|
|
74
|
+
if weights.dtype != out.dtype:
|
|
75
|
+
weights = weights.to(out.dtype)
|
|
76
|
+
|
|
65
77
|
# Loop over each sub-matrix
|
|
66
78
|
for i in range(self.buffer_n):
|
|
67
79
|
start = i * self.buffer_h
|
|
@@ -88,6 +100,12 @@ class OctreeDWConvBase(OctreeConvBase):
|
|
|
88
100
|
r''' Computes the gradient of the weight matrix.
|
|
89
101
|
'''
|
|
90
102
|
|
|
103
|
+
# Type check
|
|
104
|
+
if data.dtype != out.dtype:
|
|
105
|
+
data = data.to(out.dtype)
|
|
106
|
+
if grad.dtype != out.dtype:
|
|
107
|
+
grad = grad.to(out.dtype)
|
|
108
|
+
|
|
91
109
|
# Record the shape of out
|
|
92
110
|
out_shape = out.shape
|
|
93
111
|
out = out.flatten(0, 1)
|
ocnn/octree/octree.py
CHANGED
|
@@ -274,7 +274,7 @@ class Octree:
|
|
|
274
274
|
children[0] = 0
|
|
275
275
|
|
|
276
276
|
# update octree
|
|
277
|
-
self.children[depth] = children
|
|
277
|
+
self.children[depth] = children.int()
|
|
278
278
|
self.nnum_nempty[depth] = nnum_nempty
|
|
279
279
|
|
|
280
280
|
def octree_grow(self, depth: int, update_neigh: bool = True):
|
|
@@ -498,7 +498,7 @@ class Octree:
|
|
|
498
498
|
# normalize xyz to [-1, 1] since the average points are in range [0, 2^d]
|
|
499
499
|
if rescale:
|
|
500
500
|
scale = 2 ** (1 - depth)
|
|
501
|
-
xyz =
|
|
501
|
+
xyz = xyz * scale - 1.0
|
|
502
502
|
|
|
503
503
|
# construct Points
|
|
504
504
|
out = Points(xyz, self.normals[depth], self.features[depth],
|
ocnn/octree/points.py
CHANGED
|
@@ -56,11 +56,16 @@ class Points:
|
|
|
56
56
|
assert self.features.dim() == 2
|
|
57
57
|
assert self.features.size(0) == self.points.size(0)
|
|
58
58
|
if self.labels is not None:
|
|
59
|
-
assert self.labels.dim() == 2
|
|
59
|
+
assert self.labels.dim() == 2 or self.labels.dim() == 1
|
|
60
60
|
assert self.labels.size(0) == self.points.size(0)
|
|
61
|
+
if self.labels.dim() == 1:
|
|
62
|
+
self.labels = self.labels.unsqueeze(1)
|
|
61
63
|
if self.batch_id is not None:
|
|
62
|
-
assert self.batch_id.dim() == 2
|
|
64
|
+
assert self.batch_id.dim() == 2 or self.batch_id.dim() == 1
|
|
63
65
|
assert self.batch_id.size(0) == self.points.size(0)
|
|
66
|
+
assert self.batch_id.size(1) == 1
|
|
67
|
+
if self.batch_id.dim() == 1:
|
|
68
|
+
self.batch_id = self.batch_id.unsqueeze(1)
|
|
64
69
|
|
|
65
70
|
@property
|
|
66
71
|
def npt(self):
|
ocnn/utils.py
CHANGED
|
@@ -173,7 +173,7 @@ def resize_with_last_val(list_in: list, num: int = 3):
|
|
|
173
173
|
'''
|
|
174
174
|
|
|
175
175
|
assert (type(list_in) is list and len(list_in) < num + 1)
|
|
176
|
-
for
|
|
176
|
+
for _ in range(len(list_in), num):
|
|
177
177
|
list_in.append(list_in[-1])
|
|
178
178
|
return list_in
|
|
179
179
|
|
|
@@ -186,15 +186,18 @@ def list2str(list_in: list):
|
|
|
186
186
|
return ''.join(out)
|
|
187
187
|
|
|
188
188
|
|
|
189
|
-
def build_example_octree(depth: int = 5, full_depth: int = 2):
|
|
190
|
-
r''' Builds an example octree on CPU from 3 points.
|
|
189
|
+
def build_example_octree(depth: int = 5, full_depth: int = 2, pt_num: int = 3):
|
|
190
|
+
r''' Builds an example octree on CPU from at most 3 points.
|
|
191
191
|
'''
|
|
192
192
|
# initialize the point cloud
|
|
193
193
|
points = torch.Tensor([[-1, -1, -1], [0, 0, -1], [0.0625, 0.0625, -1]])
|
|
194
194
|
normals = torch.Tensor([[1, 0, 0], [-1, 0, 0], [0, 1, 0]])
|
|
195
195
|
features = torch.Tensor([[1, -1], [2, -2], [3, -3]])
|
|
196
196
|
labels = torch.Tensor([[0], [2], [2]])
|
|
197
|
-
|
|
197
|
+
|
|
198
|
+
assert pt_num <= 3 and pt_num > 0
|
|
199
|
+
point_cloud = ocnn.octree.Points(
|
|
200
|
+
points[:pt_num], normals[:pt_num], features[:pt_num], labels[:pt_num])
|
|
198
201
|
|
|
199
202
|
# build octree
|
|
200
203
|
octree = ocnn.octree.Octree(depth, full_depth)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ocnn
|
|
3
|
-
Version: 2.2.
|
|
3
|
+
Version: 2.2.3
|
|
4
4
|
Summary: Octree-based Sparse Convolutional Neural Networks
|
|
5
5
|
Home-page: https://github.com/octree-nn/ocnn-pytorch
|
|
6
6
|
Author: Peng-Shuai Wang
|
|
@@ -43,14 +43,14 @@ The key difference is that our O-CNN uses the `octree` to index the sparse
|
|
|
43
43
|
voxels, while these 3 works use the `Hash Table`.
|
|
44
44
|
|
|
45
45
|
Our O-CNN is published in SIGGRAPH 2017, H-CNN is published in TVCG 2018,
|
|
46
|
-
SparseConvNet is published in CVPR 2018, and MinkowskiNet is published in
|
|
46
|
+
SparseConvNet is published in CVPR 2018, and MinkowskiNet is published in
|
|
47
47
|
CVPR 2019. Actually, our O-CNN was submitted to SIGGRAPH at the end of 2016 and
|
|
48
48
|
was officially accepted in March, 2017. The camera-ready version of our O-CNN was
|
|
49
49
|
submitted to SIGGRAPH in April, 2017. We just did not post our paper on Arxiv
|
|
50
50
|
during the review process of SIGGRAPH. Therefore, **the idea of constraining CNN
|
|
51
51
|
computation into sparse non-empty voxels was first proposed by our O-CNN**.
|
|
52
52
|
Currently, this type of 3D convolution is known as Sparse Convolution in the
|
|
53
|
-
research community.
|
|
53
|
+
research community.
|
|
54
54
|
|
|
55
55
|
## Key benefits of ocnn-pytorch
|
|
56
56
|
|
|
@@ -65,3 +65,15 @@ research community.
|
|
|
65
65
|
training settings, MinkowskiNet 0.4.3 takes 60 hours and MinkowskiNet 0.5.4
|
|
66
66
|
takes 30 hours.
|
|
67
67
|
|
|
68
|
+
## Citation
|
|
69
|
+
|
|
70
|
+
```bibtex
|
|
71
|
+
@article {Wang-2017-ocnn,
|
|
72
|
+
title = {{O-CNN}: Octree-based Convolutional Neural Networks for {3D} Shape Analysis},
|
|
73
|
+
author = {Wang, Peng-Shuai and Liu, Yang and Guo, Yu-Xiao and Sun, Chun-Yu and Tong, Xin},
|
|
74
|
+
journal = {ACM Transactions on Graphics (SIGGRAPH)},
|
|
75
|
+
volume = {36},
|
|
76
|
+
number = {4},
|
|
77
|
+
year = {2017},
|
|
78
|
+
}
|
|
79
|
+
```
|
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
ocnn/__init__.py,sha256=
|
|
1
|
+
ocnn/__init__.py,sha256=kCGigMn_30MohB49Hwy0ZBXf5HTUA_i-QOjckrM48Nc,582
|
|
2
2
|
ocnn/dataset.py,sha256=wvclvjlZs9qTeMXWLaO32K5d1VVY9XHSNuVVJEpVeeo,5266
|
|
3
|
-
ocnn/utils.py,sha256=
|
|
3
|
+
ocnn/utils.py,sha256=XhykveOjHoQd94gjJ5-opzXs-9MOCAzZ34ArZ8mG4sE,6726
|
|
4
4
|
ocnn/models/__init__.py,sha256=F9PJRhOPHc1OrwkqcfywEBW0J6jmVW7-IHgWjGpY15U,724
|
|
5
5
|
ocnn/models/autoencoder.py,sha256=TjOet3dbZLexz-PvZdeV0mbIuGpUAeSH6KWU_z_s-d8,5906
|
|
6
6
|
ocnn/models/hrnet.py,sha256=9W2fi7Fuw0JXDBiZtEoUW2K7ghtpWUm_BWd-mKoHLY0,6684
|
|
7
7
|
ocnn/models/image2shape.py,sha256=5djcOHJh2SQCwd5XdLPeL5vlQDNWnRU3tJY3ojRI8aQ,4589
|
|
8
|
-
ocnn/models/lenet.py,sha256=
|
|
8
|
+
ocnn/models/lenet.py,sha256=ujVBxnn8AKiIKqB4WHxuK728oJe4TZdqZUTTNy8y3zE,1754
|
|
9
9
|
ocnn/models/ounet.py,sha256=sGGVHqF1nKw798vyththxp7-P3wHlsDUO39hbQpUVTw,3303
|
|
10
10
|
ocnn/models/resnet.py,sha256=9gZKbhFituqPJHCm-bp1xTtHKZ6wu77s31a0q7RtXoQ,2029
|
|
11
11
|
ocnn/models/segnet.py,sha256=VfZf8gBMSPgO5m8Agsfccw9snTrI4LGh0SqxhLZ1F8s,2575
|
|
@@ -13,24 +13,24 @@ ocnn/models/unet.py,sha256=1FZbTvmWg6sMYkcZyNxcSr_qN6bfOOag-tIwAoqIPKU,4123
|
|
|
13
13
|
ocnn/modules/__init__.py,sha256=BAEZybvtwDQf7yPZpGcptH1Oxf654Alq87G9D0vKw-E,820
|
|
14
14
|
ocnn/modules/modules.py,sha256=7VlBsbwN49J8Xea0DXcFHVhN8POxhzC3FNMY3_JHukM,6207
|
|
15
15
|
ocnn/modules/resblocks.py,sha256=Xh5EH6aSx8zcLk3D4m4lZiWDejLMu5AQDH7ZDI4p23A,4548
|
|
16
|
-
ocnn/nn/__init__.py,sha256=
|
|
16
|
+
ocnn/nn/__init__.py,sha256=6M3GEbyepHCU7l_QU1blSI98M0X4d3jbCsXmVGtXYj0,1769
|
|
17
17
|
ocnn/nn/octree2col.py,sha256=07BGGJD0je0VD-VdS_aKtDo7gWKNWdgojeL1a2n4VRQ,2137
|
|
18
18
|
ocnn/nn/octree2vox.py,sha256=QgyxxZvRvw2taHFWCvcGyCFAnFC6D6nfw71HrvwN3PI,1524
|
|
19
19
|
ocnn/nn/octree_align.py,sha256=Y12GBKS-F3JtRNaDdDcBFDBvhndoRxeMOVua4RB5HZE,1696
|
|
20
|
-
ocnn/nn/octree_conv.py,sha256=
|
|
20
|
+
ocnn/nn/octree_conv.py,sha256=a6-lCkJR5NJL7lBKZNWklzQUnDLj6hIysGnr45_xCQk,14942
|
|
21
21
|
ocnn/nn/octree_drop.py,sha256=croMHtk0JScDT0nLpdmbiMnkM_b5uVAz6sOEUcta6sY,1963
|
|
22
|
-
ocnn/nn/octree_dwconv.py,sha256=
|
|
22
|
+
ocnn/nn/octree_dwconv.py,sha256=cIghi7zyMUGhTd6QsU4PS6K2VlwXaYeLNfZFt2xuVi4,7350
|
|
23
23
|
ocnn/nn/octree_gconv.py,sha256=Ogkn7wE49dDdP0X_aBthNfJXMsqcRn3ufJuYQtEEjUI,2928
|
|
24
24
|
ocnn/nn/octree_interp.py,sha256=yQVjiKMNLU1XakPeYjW5ArMsw2fxFRZdh90Nn3cWQuE,7108
|
|
25
25
|
ocnn/nn/octree_norm.py,sha256=Mbn28Hv-CEWt2WA0Pdhj2p127m10vwhepzCJVjiZYMI,2976
|
|
26
26
|
ocnn/nn/octree_pad.py,sha256=suV6Ftb-UlUuoJzKdQ9DCP9oqVQJq7vXb9_6hq-kUk4,1323
|
|
27
27
|
ocnn/nn/octree_pool.py,sha256=Zn2XLk5SFl6pqMhhKIvu1uZl5ebonFSlPcKr54fOIPA,6664
|
|
28
28
|
ocnn/octree/__init__.py,sha256=vKZFc5_r6Gxg5KsPWiCZCR-umWWfWPoE7qBx4PIrUGA,630
|
|
29
|
-
ocnn/octree/octree.py,sha256=
|
|
30
|
-
ocnn/octree/points.py,sha256=
|
|
29
|
+
ocnn/octree/octree.py,sha256=PgWiECd3h49AZHrK_0GxwwAgnG9l37AcJILanM7d1-k,23713
|
|
30
|
+
ocnn/octree/points.py,sha256=7_iiN0y0g9SgV8kLzuGW20C-MFOmz5hh1aQhDhUIa0Q,11355
|
|
31
31
|
ocnn/octree/shuffled_key.py,sha256=UJZ4eKNA_7nLbf9FbEvS_3VyrAqnZCOzk1hsPtJianM,3936
|
|
32
|
-
ocnn-2.2.
|
|
33
|
-
ocnn-2.2.
|
|
34
|
-
ocnn-2.2.
|
|
35
|
-
ocnn-2.2.
|
|
36
|
-
ocnn-2.2.
|
|
32
|
+
ocnn-2.2.3.dist-info/LICENSE,sha256=YeOS0Plo8Uistv_8ZXdgddmN9GHJKnIiJ5FZ8zTW6Sw,1114
|
|
33
|
+
ocnn-2.2.3.dist-info/METADATA,sha256=Uh2UlUwXzRfM9G9d1EBRIiZARcumPCVZPY3CLjR2q8U,3682
|
|
34
|
+
ocnn-2.2.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
|
35
|
+
ocnn-2.2.3.dist-info/top_level.txt,sha256=ayZdVOnxlOke3kgzAlrRh2IEL_qOudwOaEU3xhjtpZ0,5
|
|
36
|
+
ocnn-2.2.3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|