ocnn 2.2.8__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46):
  1. ocnn/__init__.py +24 -24
  2. ocnn/dataset.py +160 -160
  3. ocnn/models/__init__.py +29 -29
  4. ocnn/models/autoencoder.py +155 -155
  5. ocnn/models/hrnet.py +192 -192
  6. ocnn/models/image2shape.py +128 -128
  7. ocnn/models/lenet.py +46 -46
  8. ocnn/models/ounet.py +94 -94
  9. ocnn/models/resnet.py +53 -53
  10. ocnn/models/segnet.py +72 -72
  11. ocnn/models/unet.py +105 -105
  12. ocnn/modules/__init__.py +26 -26
  13. ocnn/modules/modules.py +303 -303
  14. ocnn/modules/resblocks.py +158 -158
  15. ocnn/nn/__init__.py +45 -44
  16. ocnn/nn/kernels/__init__.py +14 -0
  17. ocnn/nn/kernels/autotuner.py +416 -0
  18. ocnn/nn/kernels/config.py +67 -0
  19. ocnn/nn/kernels/conv_bwd_implicit_gemm.py +229 -0
  20. ocnn/nn/kernels/conv_bwd_implicit_gemm_splitk.py +347 -0
  21. ocnn/nn/kernels/conv_fwd_implicit_gemm.py +109 -0
  22. ocnn/nn/kernels/conv_fwd_implicit_gemm_splitk.py +150 -0
  23. ocnn/nn/kernels/utils.py +44 -0
  24. ocnn/nn/octree2col.py +53 -53
  25. ocnn/nn/octree2vox.py +50 -50
  26. ocnn/nn/octree_align.py +46 -46
  27. ocnn/nn/octree_conv.py +430 -429
  28. ocnn/nn/octree_conv_t.py +148 -0
  29. ocnn/nn/octree_drop.py +55 -55
  30. ocnn/nn/octree_dwconv.py +222 -222
  31. ocnn/nn/octree_gconv.py +79 -79
  32. ocnn/nn/octree_interp.py +196 -196
  33. ocnn/nn/octree_norm.py +126 -126
  34. ocnn/nn/octree_pad.py +39 -39
  35. ocnn/nn/octree_pool.py +200 -200
  36. ocnn/octree/__init__.py +22 -22
  37. ocnn/octree/octree.py +770 -770
  38. ocnn/octree/points.py +384 -323
  39. ocnn/octree/shuffled_key.py +115 -115
  40. ocnn/utils.py +205 -205
  41. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/METADATA +117 -111
  42. ocnn-2.3.0.dist-info/RECORD +45 -0
  43. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/WHEEL +1 -1
  44. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/licenses/LICENSE +21 -21
  45. ocnn-2.2.8.dist-info/RECORD +0 -36
  46. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/top_level.txt +0 -0
ocnn/nn/octree_pad.py CHANGED
@@ -1,39 +1,39 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
-
10
- from ocnn.octree import Octree
11
-
12
-
13
def octree_pad(data: torch.Tensor, octree: Octree, depth: int, val: float = 0.0):
    r''' Pads :attr:`val` to make the number of elements of :attr:`data` equal to
    the octree node number.

    Args:
        data (torch.Tensor): The input tensor whose first dimension equals the
            non-empty octree node number at :attr:`depth`. Any trailing
            dimensions (e.g. channels) are preserved, so both ``(N,)`` and
            ``(N, C, ...)`` inputs are accepted.
        octree (Octree): The corresponding octree.
        depth (int): The depth of current octree.
        val (float): The padding value. (Default: :obj:`0.0`)

    Returns:
        torch.Tensor: A tensor with ``octree.nnum[depth]`` rows and the same
        trailing shape, dtype and device as :attr:`data`. Rows selected by
        ``octree.nempty_index(depth)`` hold :attr:`data`; all other rows are
        filled with :attr:`val`.
    '''

    idx = octree.nempty_index(depth)
    # Keep every trailing dimension instead of assuming a 2-D (N, C) input.
    size = (octree.nnum[depth],) + tuple(data.shape[1:])
    out = torch.full(size, val, dtype=data.dtype, device=data.device)
    out[idx] = data
    return out
30
-
31
-
32
def octree_depad(data: torch.Tensor, octree: Octree, depth: int):
    r''' Reverse operation of :func:`octree_pad`.

    Keeps only the rows of :attr:`data` selected by
    ``octree.nempty_index(depth)``, i.e. it drops the rows that
    :func:`octree_pad` would have filled with the padding value.

    Args:
        data (torch.Tensor): The input tensor with its first dimension equal to
            the octree node number at :attr:`depth`.
        octree (Octree): The corresponding octree.
        depth (int): The depth of current octree.

    Returns:
        torch.Tensor: ``data[octree.nempty_index(depth)]``.
    '''

    idx = octree.nempty_index(depth)
    return data[idx]
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+
10
+ from ocnn.octree import Octree
11
+
12
+
13
def octree_pad(data: torch.Tensor, octree: Octree, depth: int, val: float = 0.0):
    r''' Pads :attr:`val` to make the number of elements of :attr:`data` equal to
    the octree node number.

    Args:
        data (torch.Tensor): The input tensor whose first dimension equals the
            non-empty octree node number at :attr:`depth`. Any trailing
            dimensions (e.g. channels) are preserved, so both ``(N,)`` and
            ``(N, C, ...)`` inputs are accepted.
        octree (Octree): The corresponding octree.
        depth (int): The depth of current octree.
        val (float): The padding value. (Default: :obj:`0.0`)

    Returns:
        torch.Tensor: A tensor with ``octree.nnum[depth]`` rows and the same
        trailing shape, dtype and device as :attr:`data`. Rows selected by
        ``octree.nempty_index(depth)`` hold :attr:`data`; all other rows are
        filled with :attr:`val`.
    '''

    idx = octree.nempty_index(depth)
    # Keep every trailing dimension instead of assuming a 2-D (N, C) input.
    size = (octree.nnum[depth],) + tuple(data.shape[1:])
    out = torch.full(size, val, dtype=data.dtype, device=data.device)
    out[idx] = data
    return out
30
+
31
+
32
def octree_depad(data: torch.Tensor, octree: Octree, depth: int):
    r''' Reverse operation of :func:`octree_pad`.

    Keeps only the rows of :attr:`data` selected by
    ``octree.nempty_index(depth)``, i.e. it drops the rows that
    :func:`octree_pad` would have filled with the padding value.

    Args:
        data (torch.Tensor): The input tensor with its first dimension equal to
            the octree node number at :attr:`depth`.
        octree (Octree): The corresponding octree.
        depth (int): The depth of current octree.

    Returns:
        torch.Tensor: ``data[octree.nempty_index(depth)]``.
    '''

    idx = octree.nempty_index(depth)
    return data[idx]
ocnn/nn/octree_pool.py CHANGED
@@ -1,200 +1,200 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.nn
10
- from typing import List
11
-
12
- from ocnn.octree import Octree
13
- from ocnn.utils import meshgrid, scatter_add, resize_with_last_val, list2str
14
- from . import octree_pad, octree_depad
15
-
16
-
17
def octree_max_pool(data: torch.Tensor, octree: Octree, depth: int,
                    nempty: bool = False, return_indices: bool = False):
    r''' Performs octree max pooling with kernel size 2 and stride 2.

    The features are grouped into consecutive blocks of 8 rows, and each block
    is reduced with a channel-wise max. The code assumes that the 8 children
    of a parent node are stored next to each other.

    Args:
        data (torch.Tensor): The input tensor of shape (N, C).
        octree (Octree): The corresponding octree.
        depth (int): The depth of current octree. After pooling, the corresponding
            depth decreased by 1.
        nempty (bool): If True, :attr:`data` contains only features of non-empty
            octree nodes.
        return_indices (bool): If True, returns the indices, which can be used in
            :func:`octree_max_unpool`.

    Returns:
        torch.Tensor or Tuple[torch.Tensor, torch.Tensor]: The pooled features
        at depth ``depth - 1``. If :attr:`nempty` is False, they are padded to
        all octree nodes at that depth via :func:`octree_pad`. If
        :attr:`return_indices` is True, the position (0-7) of the maximum
        inside each block of 8 rows is also returned, with one row per block.
    '''

    if nempty:
        # Fill empty nodes with -inf so they can never be selected by the max.
        data = octree_pad(data, octree, depth, float('-inf'))
    # Use ``reshape`` rather than ``view``: on the nempty=False path the
    # tensor comes straight from the caller and may be non-contiguous, in
    # which case ``view`` raises.
    data = data.reshape(-1, 8, data.shape[1])
    out, indices = data.max(dim=1)
    if not nempty:
        out = octree_pad(out, octree, depth - 1)
    return (out, indices) if return_indices else out
39
-
40
-
41
def octree_max_unpool(data: torch.Tensor, indices: torch.Tensor, octree: Octree,
                      depth: int, nempty: bool = False):
    r''' Performs octree max unpooling.

    Each row of :attr:`data` is expanded into 8 child rows. Every feature
    value is written into the child slot recorded in :attr:`indices`, and
    the remaining slots are zero.

    Args:
        data (torch.Tensor): The input tensor.
        indices (torch.Tensor): The indices returned by :func:`octree_max_pool`. The
            depth of :attr:`indices` is larger by 1 than :attr:`data`.
        octree (Octree): The corresponding octree.
        depth (int): The depth of current data. After unpooling, the corresponding
            depth increases by 1.
        nempty (bool): If True, :attr:`data` contains only features of non-empty
            octree nodes, and the output is restricted to the non-empty nodes
            at depth ``depth + 1``. Otherwise the input is first depadded and
            the output keeps all 8 children of every remaining row.
    '''

    if not nempty:
        # Keep only the rows of non-empty nodes; each one expands to 8 children.
        data = octree_depad(data, octree, depth)
    rows, channels = data.shape
    unpooled = torch.zeros(rows, 8, channels, dtype=data.dtype, device=data.device)
    row_idx = torch.arange(rows, dtype=indices.dtype, device=indices.device)
    chan_idx = torch.arange(channels, dtype=indices.dtype, device=indices.device)
    row_idx, chan_idx = meshgrid(row_idx, chan_idx, indexing='ij')
    # Scatter every (row, channel) value into the child slot chosen by pooling.
    unpooled[row_idx, indices, chan_idx] = data
    unpooled = unpooled.view(-1, channels)
    return octree_depad(unpooled, octree, depth + 1) if nempty else unpooled
66
-
67
-
68
def octree_avg_pool(data: torch.Tensor, octree: Octree, depth: int,
                    kernel: str, stride: int = 2, nempty: bool = False):
    r''' Performs octree average pooling.

    The neighborhood table from ``octree.get_neigh`` is turned into a sparse
    averaging matrix of shape (N_out, N_in), which is multiplied with
    :attr:`data`. Negative entries of the table (missing neighbors) are
    ignored, so each output row is the mean over its valid neighbors only.

    Args:
        data (torch.Tensor): The input tensor.
        octree (Octree): The corresponding octree.
        depth (int): The depth of current octree.
        kernel (str): The kernel size, like '333', '222'.
        stride (int): The stride of the pooling.
        nempty (bool): If True, :attr:`data` contains only features of non-empty
            octree nodes.

    Returns:
        torch.Tensor: The pooled features, one row per row of the
        neighborhood table.
    '''

    neigh = octree.get_neigh(depth, kernel, stride, nempty)

    N1 = data.shape[0]   # number of input rows
    N2 = neigh.shape[0]  # number of output rows
    K = neigh.shape[1]   # number of neighbor slots per output row

    mask = neigh >= 0
    # Equal weight for each valid neighbor. The epsilon avoids division by
    # zero for rows with no valid neighbor. Cast to the input dtype so that
    # ``torch.sparse.mm`` does not fail on a dtype mismatch when ``data`` is
    # not float32 (the raw weights are always float32).
    val = (1.0 / (torch.sum(mask, dim=1) + 1e-8)).to(data.dtype)
    val = val.unsqueeze(1).expand(N2, K)[mask]

    row = torch.arange(N2, device=neigh.device).unsqueeze(1).expand(N2, K)
    indices = torch.stack([row[mask], neigh[mask]], dim=0).long()

    mat = torch.sparse_coo_tensor(indices, val, [N2, N1], device=data.device)
    return torch.sparse.mm(mat, data)
102
-
103
-
104
def octree_global_pool(data: torch.Tensor, octree: Octree, depth: int,
                       nempty: bool = False):
    r''' Performs octree global average pooling.

    All rows of :attr:`data` that share the same batch id are averaged,
    which yields one feature row per batch element.

    Args:
        data (torch.Tensor): The input tensor.
        octree (Octree): The corresponding octree.
        depth (int): The depth of current octree.
        nempty (bool): If True, :attr:`data` contains only features of non-empty
            octree nodes.

    Returns:
        torch.Tensor: A tensor with ``octree.batch_size`` rows.
    '''

    num_batches = octree.batch_size
    owner = octree.batch_id(depth, nempty)
    counts = scatter_add(data.new_ones(data.shape[0], 1), owner,
                         dim=0, dim_size=num_batches)
    # A batch element may own no rows; clamp its count so the division is safe.
    counts = counts.clamp(min=1)

    totals = scatter_add(data, owner, dim=0, dim_size=num_batches)
    return totals / counts
125
-
126
-
127
class OctreePoolBase(torch.nn.Module):
    r''' The base class for octree-based pooling.

    Stores the pooling configuration shared by the concrete pooling modules.

    Args:
        kernel_size (List[int]): The kernel size. It is normalized with
            ``resize_with_last_val`` and also kept in string form
            (``self.kernel``, built by ``list2str``).
        stride (int): The pooling stride.
        nempty (bool): If True, the input features only cover non-empty
            octree nodes.
    '''

    def __init__(self, kernel_size: List[int], stride: int, nempty: bool = False):
        super().__init__()
        normalized = resize_with_last_val(kernel_size)
        self.kernel_size = normalized
        self.kernel = list2str(normalized)
        self.stride = stride
        self.nempty = nempty

    def extra_repr(self) -> str:
        return (f'kernel_size={self.kernel_size}, stride={self.stride}, '
                f'nempty={self.nempty}')
141
-
142
-
143
class OctreeMaxPool(OctreePoolBase):
    r''' Module wrapper of :func:`octree_max_pool` (kernel size 2, stride 2).

    Args:
        nempty (bool): If True, the input only contains features of non-empty
            octree nodes.
        return_indices (bool): If True, :meth:`forward` also returns the
            indices needed by :func:`octree_max_unpool`.
    '''

    def __init__(self, nempty: bool = False, return_indices: bool = False):
        super().__init__(kernel_size=[2], stride=2, nempty=nempty)
        self.return_indices = return_indices

    def forward(self, data: torch.Tensor, octree: Octree, depth: int):
        r''' Applies :func:`octree_max_pool` with this module's settings. '''

        return octree_max_pool(data, octree, depth,
                               nempty=self.nempty,
                               return_indices=self.return_indices)
157
-
158
-
159
class OctreeMaxUnpool(OctreePoolBase):
    r''' Module wrapper of :func:`octree_max_unpool` (kernel size 2, stride 2).

    Args:
        nempty (bool): If True, the input only contains features of non-empty
            octree nodes.
    '''

    def __init__(self, nempty: bool = False):
        super().__init__(kernel_size=[2], stride=2, nempty=nempty)

    def forward(self, data: torch.Tensor, indices: torch.Tensor, octree: Octree,
                depth: int):
        r''' Applies :func:`octree_max_unpool` with this module's settings. '''

        return octree_max_unpool(data, indices, octree, depth,
                                 nempty=self.nempty)
173
-
174
-
175
class OctreeGlobalPool(OctreePoolBase):
    r''' Module wrapper of :func:`octree_global_pool`.

    Args:
        nempty (bool): If True, the input only contains features of non-empty
            octree nodes.
    '''

    def __init__(self, nempty: bool = False):
        # Global pooling does not use a kernel or stride; -1 is a placeholder
        # that only shows up in ``extra_repr``.
        super().__init__(kernel_size=[-1], stride=-1, nempty=nempty)

    def forward(self, data: torch.Tensor, octree: Octree, depth: int):
        r''' Applies :func:`octree_global_pool` with this module's settings. '''

        return octree_global_pool(data, octree, depth, nempty=self.nempty)
188
-
189
-
190
class OctreeAvgPool(OctreePoolBase):
    r''' Module wrapper of :func:`octree_avg_pool`.

    The constructor is inherited from :class:`OctreePoolBase`
    (``kernel_size``, ``stride``, ``nempty``).
    '''

    def forward(self, data: torch.Tensor, octree: Octree, depth: int):
        r''' Applies :func:`octree_avg_pool` with this module's settings. '''

        return octree_avg_pool(data, octree, depth, self.kernel,
                               stride=self.stride, nempty=self.nempty)
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn
10
+ from typing import List
11
+
12
+ from ocnn.octree import Octree
13
+ from ocnn.utils import meshgrid, scatter_add, resize_with_last_val, list2str
14
+ from . import octree_pad, octree_depad
15
+
16
+
17
def octree_max_pool(data: torch.Tensor, octree: Octree, depth: int,
                    nempty: bool = False, return_indices: bool = False):
    r''' Performs octree max pooling with kernel size 2 and stride 2.

    The features are grouped into consecutive blocks of 8 rows, and each block
    is reduced with a channel-wise max. The code assumes that the 8 children
    of a parent node are stored next to each other.

    Args:
        data (torch.Tensor): The input tensor of shape (N, C).
        octree (Octree): The corresponding octree.
        depth (int): The depth of current octree. After pooling, the corresponding
            depth decreased by 1.
        nempty (bool): If True, :attr:`data` contains only features of non-empty
            octree nodes.
        return_indices (bool): If True, returns the indices, which can be used in
            :func:`octree_max_unpool`.

    Returns:
        torch.Tensor or Tuple[torch.Tensor, torch.Tensor]: The pooled features
        at depth ``depth - 1``. If :attr:`nempty` is False, they are padded to
        all octree nodes at that depth via :func:`octree_pad`. If
        :attr:`return_indices` is True, the position (0-7) of the maximum
        inside each block of 8 rows is also returned, with one row per block.
    '''

    if nempty:
        # Fill empty nodes with -inf so they can never be selected by the max.
        data = octree_pad(data, octree, depth, float('-inf'))
    # Use ``reshape`` rather than ``view``: on the nempty=False path the
    # tensor comes straight from the caller and may be non-contiguous, in
    # which case ``view`` raises.
    data = data.reshape(-1, 8, data.shape[1])
    out, indices = data.max(dim=1)
    if not nempty:
        out = octree_pad(out, octree, depth - 1)
    return (out, indices) if return_indices else out
39
+
40
+
41
def octree_max_unpool(data: torch.Tensor, indices: torch.Tensor, octree: Octree,
                      depth: int, nempty: bool = False):
    r''' Performs octree max unpooling.

    Each row of :attr:`data` is expanded into 8 child rows. Every feature
    value is written into the child slot recorded in :attr:`indices`, and
    the remaining slots are zero.

    Args:
        data (torch.Tensor): The input tensor.
        indices (torch.Tensor): The indices returned by :func:`octree_max_pool`. The
            depth of :attr:`indices` is larger by 1 than :attr:`data`.
        octree (Octree): The corresponding octree.
        depth (int): The depth of current data. After unpooling, the corresponding
            depth increases by 1.
        nempty (bool): If True, :attr:`data` contains only features of non-empty
            octree nodes, and the output is restricted to the non-empty nodes
            at depth ``depth + 1``. Otherwise the input is first depadded and
            the output keeps all 8 children of every remaining row.
    '''

    if not nempty:
        # Keep only the rows of non-empty nodes; each one expands to 8 children.
        data = octree_depad(data, octree, depth)
    rows, channels = data.shape
    unpooled = torch.zeros(rows, 8, channels, dtype=data.dtype, device=data.device)
    row_idx = torch.arange(rows, dtype=indices.dtype, device=indices.device)
    chan_idx = torch.arange(channels, dtype=indices.dtype, device=indices.device)
    row_idx, chan_idx = meshgrid(row_idx, chan_idx, indexing='ij')
    # Scatter every (row, channel) value into the child slot chosen by pooling.
    unpooled[row_idx, indices, chan_idx] = data
    unpooled = unpooled.view(-1, channels)
    return octree_depad(unpooled, octree, depth + 1) if nempty else unpooled
66
+
67
+
68
def octree_avg_pool(data: torch.Tensor, octree: Octree, depth: int,
                    kernel: str, stride: int = 2, nempty: bool = False):
    r''' Performs octree average pooling.

    The neighborhood table from ``octree.get_neigh`` is turned into a sparse
    averaging matrix of shape (N_out, N_in), which is multiplied with
    :attr:`data`. Negative entries of the table (missing neighbors) are
    ignored, so each output row is the mean over its valid neighbors only.

    Args:
        data (torch.Tensor): The input tensor.
        octree (Octree): The corresponding octree.
        depth (int): The depth of current octree.
        kernel (str): The kernel size, like '333', '222'.
        stride (int): The stride of the pooling.
        nempty (bool): If True, :attr:`data` contains only features of non-empty
            octree nodes.

    Returns:
        torch.Tensor: The pooled features, one row per row of the
        neighborhood table.
    '''

    neigh = octree.get_neigh(depth, kernel, stride, nempty)

    N1 = data.shape[0]   # number of input rows
    N2 = neigh.shape[0]  # number of output rows
    K = neigh.shape[1]   # number of neighbor slots per output row

    mask = neigh >= 0
    # Equal weight for each valid neighbor. The epsilon avoids division by
    # zero for rows with no valid neighbor. Cast to the input dtype so that
    # ``torch.sparse.mm`` does not fail on a dtype mismatch when ``data`` is
    # not float32 (the raw weights are always float32).
    val = (1.0 / (torch.sum(mask, dim=1) + 1e-8)).to(data.dtype)
    val = val.unsqueeze(1).expand(N2, K)[mask]

    row = torch.arange(N2, device=neigh.device).unsqueeze(1).expand(N2, K)
    indices = torch.stack([row[mask], neigh[mask]], dim=0).long()

    mat = torch.sparse_coo_tensor(indices, val, [N2, N1], device=data.device)
    return torch.sparse.mm(mat, data)
102
+
103
+
104
def octree_global_pool(data: torch.Tensor, octree: Octree, depth: int,
                       nempty: bool = False):
    r''' Performs octree global average pooling.

    All rows of :attr:`data` that share the same batch id are averaged,
    which yields one feature row per batch element.

    Args:
        data (torch.Tensor): The input tensor.
        octree (Octree): The corresponding octree.
        depth (int): The depth of current octree.
        nempty (bool): If True, :attr:`data` contains only features of non-empty
            octree nodes.

    Returns:
        torch.Tensor: A tensor with ``octree.batch_size`` rows.
    '''

    num_batches = octree.batch_size
    owner = octree.batch_id(depth, nempty)
    counts = scatter_add(data.new_ones(data.shape[0], 1), owner,
                         dim=0, dim_size=num_batches)
    # A batch element may own no rows; clamp its count so the division is safe.
    counts = counts.clamp(min=1)

    totals = scatter_add(data, owner, dim=0, dim_size=num_batches)
    return totals / counts
125
+
126
+
127
class OctreePoolBase(torch.nn.Module):
    r''' The base class for octree-based pooling.

    Stores the pooling configuration shared by the concrete pooling modules.

    Args:
        kernel_size (List[int]): The kernel size. It is normalized with
            ``resize_with_last_val`` and also kept in string form
            (``self.kernel``, built by ``list2str``).
        stride (int): The pooling stride.
        nempty (bool): If True, the input features only cover non-empty
            octree nodes.
    '''

    def __init__(self, kernel_size: List[int], stride: int, nempty: bool = False):
        super().__init__()
        normalized = resize_with_last_val(kernel_size)
        self.kernel_size = normalized
        self.kernel = list2str(normalized)
        self.stride = stride
        self.nempty = nempty

    def extra_repr(self) -> str:
        return (f'kernel_size={self.kernel_size}, stride={self.stride}, '
                f'nempty={self.nempty}')
141
+
142
+
143
class OctreeMaxPool(OctreePoolBase):
    r''' Module wrapper of :func:`octree_max_pool` (kernel size 2, stride 2).

    Args:
        nempty (bool): If True, the input only contains features of non-empty
            octree nodes.
        return_indices (bool): If True, :meth:`forward` also returns the
            indices needed by :func:`octree_max_unpool`.
    '''

    def __init__(self, nempty: bool = False, return_indices: bool = False):
        super().__init__(kernel_size=[2], stride=2, nempty=nempty)
        self.return_indices = return_indices

    def forward(self, data: torch.Tensor, octree: Octree, depth: int):
        r''' Applies :func:`octree_max_pool` with this module's settings. '''

        return octree_max_pool(data, octree, depth,
                               nempty=self.nempty,
                               return_indices=self.return_indices)
157
+
158
+
159
class OctreeMaxUnpool(OctreePoolBase):
    r''' Module wrapper of :func:`octree_max_unpool` (kernel size 2, stride 2).

    Args:
        nempty (bool): If True, the input only contains features of non-empty
            octree nodes.
    '''

    def __init__(self, nempty: bool = False):
        super().__init__(kernel_size=[2], stride=2, nempty=nempty)

    def forward(self, data: torch.Tensor, indices: torch.Tensor, octree: Octree,
                depth: int):
        r''' Applies :func:`octree_max_unpool` with this module's settings. '''

        return octree_max_unpool(data, indices, octree, depth,
                                 nempty=self.nempty)
173
+
174
+
175
class OctreeGlobalPool(OctreePoolBase):
    r''' Module wrapper of :func:`octree_global_pool`.

    Args:
        nempty (bool): If True, the input only contains features of non-empty
            octree nodes.
    '''

    def __init__(self, nempty: bool = False):
        # Global pooling does not use a kernel or stride; -1 is a placeholder
        # that only shows up in ``extra_repr``.
        super().__init__(kernel_size=[-1], stride=-1, nempty=nempty)

    def forward(self, data: torch.Tensor, octree: Octree, depth: int):
        r''' Applies :func:`octree_global_pool` with this module's settings. '''

        return octree_global_pool(data, octree, depth, nempty=self.nempty)
188
+
189
+
190
class OctreeAvgPool(OctreePoolBase):
    r''' Module wrapper of :func:`octree_avg_pool`.

    The constructor is inherited from :class:`OctreePoolBase`
    (``kernel_size``, ``stride``, ``nempty``).
    '''

    def forward(self, data: torch.Tensor, octree: Octree, depth: int):
        r''' Applies :func:`octree_avg_pool` with this module's settings. '''

        return octree_avg_pool(data, octree, depth, self.kernel,
                               stride=self.stride, nempty=self.nempty)
ocnn/octree/__init__.py CHANGED
@@ -1,22 +1,22 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- from .shuffled_key import key2xyz, xyz2key
9
- from .points import Points, merge_points
10
- from .octree import Octree, merge_octrees, init_octree
11
-
12
# Public names re-exported by this package, in the original order.
__all__ = [
    'key2xyz', 'xyz2key',                            # from .shuffled_key
    'Points', 'Octree',                              # core classes
    'merge_points', 'merge_octrees', 'init_octree',  # helper functions
]

# Alias bound to the very same list object as ``__all__``.
classes = __all__
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ from .shuffled_key import key2xyz, xyz2key
9
+ from .points import Points, merge_points
10
+ from .octree import Octree, merge_octrees, init_octree
11
+
12
# Public names re-exported by this package, in the original order.
__all__ = [
    'key2xyz', 'xyz2key',                            # from .shuffled_key
    'Points', 'Octree',                              # core classes
    'merge_points', 'merge_octrees', 'init_octree',  # helper functions
]

# Alias bound to the very same list object as ``__all__``.
classes = __all__