ocnn 2.2.5__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ocnn/modules/modules.py CHANGED
@@ -1,303 +1,303 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.utils.checkpoint
10
- from typing import List
11
-
12
- from ocnn.nn import OctreeConv, OctreeDeconv, OctreeGroupNorm
13
- from ocnn.octree import Octree
14
-
15
-
16
- # bn_momentum, bn_eps = 0.01, 0.001 # the default value of Tensorflow 1.x
17
- # bn_momentum, bn_eps = 0.1, 1e-05 # the default value of pytorch
18
-
19
-
20
- def ckpt_conv_wrapper(conv_op, data, octree):
21
- # The dummy tensor is a workaround when the checkpoint is used for the first conv layer:
22
- # https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/11
23
- dummy = torch.ones(1, dtype=torch.float32, requires_grad=True)
24
-
25
- def conv_wrapper(data, octree, dummy_tensor):
26
- return conv_op(data, octree)
27
-
28
- return torch.utils.checkpoint.checkpoint(conv_wrapper, data, octree, dummy)
29
-
30
-
31
- class OctreeConvBn(torch.nn.Module):
32
- r''' A sequence of :class:`OctreeConv` and :obj:`BatchNorm`.
33
-
34
- Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
35
- '''
36
-
37
- def __init__(self, in_channels: int, out_channels: int,
38
- kernel_size: List[int] = [3], stride: int = 1,
39
- nempty: bool = False):
40
- super().__init__()
41
- self.conv = OctreeConv(
42
- in_channels, out_channels, kernel_size, stride, nempty)
43
- self.bn = torch.nn.BatchNorm1d(out_channels)
44
-
45
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
46
- r''''''
47
-
48
- out = self.conv(data, octree, depth)
49
- out = self.bn(out)
50
- return out
51
-
52
-
53
- class OctreeConvBnRelu(torch.nn.Module):
54
- r''' A sequence of :class:`OctreeConv`, :obj:`BatchNorm`, and :obj:`Relu`.
55
-
56
- Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
57
- '''
58
-
59
- def __init__(self, in_channels: int, out_channels: int,
60
- kernel_size: List[int] = [3], stride: int = 1,
61
- nempty: bool = False):
62
- super().__init__()
63
- self.conv = OctreeConv(
64
- in_channels, out_channels, kernel_size, stride, nempty)
65
- self.bn = torch.nn.BatchNorm1d(out_channels)
66
- self.relu = torch.nn.ReLU(inplace=True)
67
-
68
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
69
- r''''''
70
-
71
- out = self.conv(data, octree, depth)
72
- out = self.bn(out)
73
- out = self.relu(out)
74
- return out
75
-
76
-
77
- class OctreeDeconvBnRelu(torch.nn.Module):
78
- r''' A sequence of :class:`OctreeDeconv`, :obj:`BatchNorm`, and :obj:`Relu`.
79
-
80
- Please refer to :class:`ocnn.nn.OctreeDeconv` for details on the parameters.
81
- '''
82
-
83
- def __init__(self, in_channels: int, out_channels: int,
84
- kernel_size: List[int] = [3], stride: int = 1,
85
- nempty: bool = False):
86
- super().__init__()
87
- self.deconv = OctreeDeconv(
88
- in_channels, out_channels, kernel_size, stride, nempty)
89
- self.bn = torch.nn.BatchNorm1d(out_channels)
90
- self.relu = torch.nn.ReLU(inplace=True)
91
-
92
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
93
- r''''''
94
-
95
- out = self.deconv(data, octree, depth)
96
- out = self.bn(out)
97
- out = self.relu(out)
98
- return out
99
-
100
-
101
- class Conv1x1(torch.nn.Module):
102
- r''' Performs a convolution with kernel :obj:`(1,1,1)`.
103
-
104
- The shape of octree features is :obj:`(N, C)`, where :obj:`N` is the node
105
- number and :obj:`C` is the feature channel. Therefore, :class:`Conv1x1` can be
106
- implemented with :class:`torch.nn.Linear`.
107
- '''
108
-
109
- def __init__(self, in_channels: int, out_channels: int, use_bias: bool = False):
110
- super().__init__()
111
- self.linear = torch.nn.Linear(in_channels, out_channels, use_bias)
112
-
113
- def forward(self, data: torch.Tensor):
114
- r''''''
115
-
116
- return self.linear(data)
117
-
118
-
119
- class Conv1x1Bn(torch.nn.Module):
120
- r''' A sequence of :class:`Conv1x1` and :class:`BatchNorm`.
121
- '''
122
-
123
- def __init__(self, in_channels: int, out_channels: int):
124
- super().__init__()
125
- self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
126
- self.bn = torch.nn.BatchNorm1d(out_channels)
127
-
128
- def forward(self, data: torch.Tensor):
129
- r''''''
130
-
131
- out = self.conv(data)
132
- out = self.bn(out)
133
- return out
134
-
135
-
136
- class Conv1x1BnRelu(torch.nn.Module):
137
- r''' A sequence of :class:`Conv1x1`, :class:`BatchNorm` and :class:`Relu`.
138
- '''
139
-
140
- def __init__(self, in_channels: int, out_channels: int):
141
- super().__init__()
142
- self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
143
- self.bn = torch.nn.BatchNorm1d(out_channels)
144
- self.relu = torch.nn.ReLU(inplace=True)
145
-
146
- def forward(self, data: torch.Tensor):
147
- r''''''
148
-
149
- out = self.conv(data)
150
- out = self.bn(out)
151
- out = self.relu(out)
152
- return out
153
-
154
-
155
- class FcBnRelu(torch.nn.Module):
156
- r''' A sequence of :class:`FC`, :class:`BatchNorm` and :class:`Relu`.
157
- '''
158
-
159
- def __init__(self, in_channels: int, out_channels: int):
160
- super().__init__()
161
- self.flatten = torch.nn.Flatten(start_dim=1)
162
- self.fc = torch.nn.Linear(in_channels, out_channels, bias=False)
163
- self.bn = torch.nn.BatchNorm1d(out_channels)
164
- self.relu = torch.nn.ReLU(inplace=True)
165
-
166
- def forward(self, data):
167
- r''''''
168
-
169
- out = self.flatten(data)
170
- out = self.fc(out)
171
- out = self.bn(out)
172
- out = self.relu(out)
173
- return out
174
-
175
-
176
- class OctreeConvGn(torch.nn.Module):
177
- r''' A sequence of :class:`OctreeConv` and :obj:`OctreeGroupNorm`.
178
-
179
- Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
180
- '''
181
-
182
- def __init__(self, in_channels: int, out_channels: int, group: int,
183
- kernel_size: List[int] = [3], stride: int = 1,
184
- nempty: bool = False):
185
- super().__init__()
186
- self.conv = OctreeConv(
187
- in_channels, out_channels, kernel_size, stride, nempty)
188
- self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
189
-
190
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
191
- r''''''
192
-
193
- out = self.conv(data, octree, depth)
194
- out = self.gn(out, octree, depth)
195
- return out
196
-
197
-
198
- class OctreeConvGnRelu(torch.nn.Module):
199
- r''' A sequence of :class:`OctreeConv`, :obj:`OctreeGroupNorm`, and :obj:`Relu`.
200
-
201
- Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
202
- '''
203
-
204
- def __init__(self, in_channels: int, out_channels: int, group: int,
205
- kernel_size: List[int] = [3], stride: int = 1,
206
- nempty: bool = False):
207
- super().__init__()
208
- self.stride = stride
209
- self.conv = OctreeConv(
210
- in_channels, out_channels, kernel_size, stride, nempty)
211
- self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
212
- self.relu = torch.nn.ReLU(inplace=True)
213
-
214
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
215
- r''''''
216
-
217
- out = self.conv(data, octree, depth)
218
- out = self.gn(out, octree, depth if self.stride == 1 else depth - 1)
219
- out = self.relu(out)
220
- return out
221
-
222
-
223
- class OctreeDeconvGnRelu(torch.nn.Module):
224
- r''' A sequence of :class:`OctreeDeconv`, :obj:`OctreeGroupNorm`, and :obj:`Relu`.
225
-
226
- Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
227
- '''
228
-
229
- def __init__(self, in_channels: int, out_channels: int, group: int,
230
- kernel_size: List[int] = [3], stride: int = 1,
231
- nempty: bool = False):
232
- super().__init__()
233
- self.stride = stride
234
- self.deconv = OctreeDeconv(
235
- in_channels, out_channels, kernel_size, stride, nempty)
236
- self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
237
- self.relu = torch.nn.ReLU(inplace=True)
238
-
239
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
240
- r''''''
241
-
242
- out = self.deconv(data, octree, depth)
243
- out = self.gn(out, octree, depth if self.stride == 1 else depth + 1)
244
- out = self.relu(out)
245
- return out
246
-
247
-
248
- class Conv1x1Gn(torch.nn.Module):
249
- r''' A sequence of :class:`Conv1x1`, :class:`OctreeGroupNorm`.
250
- '''
251
-
252
- def __init__(self, in_channels: int, out_channels: int, group: int,
253
- nempty: bool = False):
254
- super().__init__()
255
- self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
256
- self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
257
-
258
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
259
- r''''''
260
-
261
- out = self.conv(data)
262
- out = self.gn(out, octree, depth)
263
- return out
264
-
265
-
266
- class Conv1x1GnRelu(torch.nn.Module):
267
- r''' A sequence of :class:`Conv1x1`, :class:`OctreeGroupNorm` and :class:`Relu`.
268
- '''
269
-
270
- def __init__(self, in_channels: int, out_channels: int, group: int,
271
- nempty: bool = False):
272
- super().__init__()
273
- self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
274
- self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
275
- self.relu = torch.nn.ReLU(inplace=True)
276
-
277
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
278
- r''''''
279
-
280
- out = self.conv(data)
281
- out = self.gn(out, octree, depth)
282
- out = self.relu(out)
283
- return out
284
-
285
-
286
- class InputFeature(torch.nn.Module):
287
- r''' Returns the initial input feature stored in octree.
288
-
289
- Refer to :func:`ocnn.octree.Octree.get_input_feature` for details.
290
- '''
291
-
292
- def __init__(self, feature: str = 'NDF', nempty: bool = False):
293
- super().__init__()
294
- self.nempty = nempty
295
- self.feature = feature.upper()
296
-
297
- def forward(self, octree: Octree):
298
- r''''''
299
- return octree.get_input_feature(self.feature, self.nempty)
300
-
301
- def extra_repr(self) -> str:
302
- r''''''
303
- return 'feature={}, nempty={}'.format(self.feature, self.nempty)
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.utils.checkpoint
10
+ from typing import List
11
+
12
+ from ocnn.nn import OctreeConv, OctreeDeconv, OctreeGroupNorm
13
+ from ocnn.octree import Octree
14
+
15
+
16
+ # bn_momentum, bn_eps = 0.01, 0.001 # the default value of Tensorflow 1.x
17
+ # bn_momentum, bn_eps = 0.1, 1e-05 # the default value of pytorch
18
+
19
+
20
+ def ckpt_conv_wrapper(conv_op, data, octree):
21
+ # The dummy tensor is a workaround when the checkpoint is used for the first conv layer:
22
+ # https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/11
23
+ dummy = torch.ones(1, dtype=torch.float32, requires_grad=True)
24
+
25
+ def conv_wrapper(data, octree, dummy_tensor):
26
+ return conv_op(data, octree)
27
+
28
+ return torch.utils.checkpoint.checkpoint(conv_wrapper, data, octree, dummy)
29
+
30
+
31
+ class OctreeConvBn(torch.nn.Module):
32
+ r''' A sequence of :class:`OctreeConv` and :obj:`BatchNorm`.
33
+
34
+ Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
35
+ '''
36
+
37
+ def __init__(self, in_channels: int, out_channels: int,
38
+ kernel_size: List[int] = [3], stride: int = 1,
39
+ nempty: bool = False):
40
+ super().__init__()
41
+ self.conv = OctreeConv(
42
+ in_channels, out_channels, kernel_size, stride, nempty)
43
+ self.bn = torch.nn.BatchNorm1d(out_channels)
44
+
45
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
46
+ r''''''
47
+
48
+ out = self.conv(data, octree, depth)
49
+ out = self.bn(out)
50
+ return out
51
+
52
+
53
+ class OctreeConvBnRelu(torch.nn.Module):
54
+ r''' A sequence of :class:`OctreeConv`, :obj:`BatchNorm`, and :obj:`Relu`.
55
+
56
+ Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
57
+ '''
58
+
59
+ def __init__(self, in_channels: int, out_channels: int,
60
+ kernel_size: List[int] = [3], stride: int = 1,
61
+ nempty: bool = False):
62
+ super().__init__()
63
+ self.conv = OctreeConv(
64
+ in_channels, out_channels, kernel_size, stride, nempty)
65
+ self.bn = torch.nn.BatchNorm1d(out_channels)
66
+ self.relu = torch.nn.ReLU(inplace=True)
67
+
68
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
69
+ r''''''
70
+
71
+ out = self.conv(data, octree, depth)
72
+ out = self.bn(out)
73
+ out = self.relu(out)
74
+ return out
75
+
76
+
77
+ class OctreeDeconvBnRelu(torch.nn.Module):
78
+ r''' A sequence of :class:`OctreeDeconv`, :obj:`BatchNorm`, and :obj:`Relu`.
79
+
80
+ Please refer to :class:`ocnn.nn.OctreeDeconv` for details on the parameters.
81
+ '''
82
+
83
+ def __init__(self, in_channels: int, out_channels: int,
84
+ kernel_size: List[int] = [3], stride: int = 1,
85
+ nempty: bool = False):
86
+ super().__init__()
87
+ self.deconv = OctreeDeconv(
88
+ in_channels, out_channels, kernel_size, stride, nempty)
89
+ self.bn = torch.nn.BatchNorm1d(out_channels)
90
+ self.relu = torch.nn.ReLU(inplace=True)
91
+
92
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
93
+ r''''''
94
+
95
+ out = self.deconv(data, octree, depth)
96
+ out = self.bn(out)
97
+ out = self.relu(out)
98
+ return out
99
+
100
+
101
+ class Conv1x1(torch.nn.Module):
102
+ r''' Performs a convolution with kernel :obj:`(1,1,1)`.
103
+
104
+ The shape of octree features is :obj:`(N, C)`, where :obj:`N` is the node
105
+ number and :obj:`C` is the feature channel. Therefore, :class:`Conv1x1` can be
106
+ implemented with :class:`torch.nn.Linear`.
107
+ '''
108
+
109
+ def __init__(self, in_channels: int, out_channels: int, use_bias: bool = False):
110
+ super().__init__()
111
+ self.linear = torch.nn.Linear(in_channels, out_channels, use_bias)
112
+
113
+ def forward(self, data: torch.Tensor):
114
+ r''''''
115
+
116
+ return self.linear(data)
117
+
118
+
119
+ class Conv1x1Bn(torch.nn.Module):
120
+ r''' A sequence of :class:`Conv1x1` and :class:`BatchNorm`.
121
+ '''
122
+
123
+ def __init__(self, in_channels: int, out_channels: int):
124
+ super().__init__()
125
+ self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
126
+ self.bn = torch.nn.BatchNorm1d(out_channels)
127
+
128
+ def forward(self, data: torch.Tensor):
129
+ r''''''
130
+
131
+ out = self.conv(data)
132
+ out = self.bn(out)
133
+ return out
134
+
135
+
136
+ class Conv1x1BnRelu(torch.nn.Module):
137
+ r''' A sequence of :class:`Conv1x1`, :class:`BatchNorm` and :class:`Relu`.
138
+ '''
139
+
140
+ def __init__(self, in_channels: int, out_channels: int):
141
+ super().__init__()
142
+ self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
143
+ self.bn = torch.nn.BatchNorm1d(out_channels)
144
+ self.relu = torch.nn.ReLU(inplace=True)
145
+
146
+ def forward(self, data: torch.Tensor):
147
+ r''''''
148
+
149
+ out = self.conv(data)
150
+ out = self.bn(out)
151
+ out = self.relu(out)
152
+ return out
153
+
154
+
155
+ class FcBnRelu(torch.nn.Module):
156
+ r''' A sequence of :class:`FC`, :class:`BatchNorm` and :class:`Relu`.
157
+ '''
158
+
159
+ def __init__(self, in_channels: int, out_channels: int):
160
+ super().__init__()
161
+ self.flatten = torch.nn.Flatten(start_dim=1)
162
+ self.fc = torch.nn.Linear(in_channels, out_channels, bias=False)
163
+ self.bn = torch.nn.BatchNorm1d(out_channels)
164
+ self.relu = torch.nn.ReLU(inplace=True)
165
+
166
+ def forward(self, data):
167
+ r''''''
168
+
169
+ out = self.flatten(data)
170
+ out = self.fc(out)
171
+ out = self.bn(out)
172
+ out = self.relu(out)
173
+ return out
174
+
175
+
176
+ class OctreeConvGn(torch.nn.Module):
177
+ r''' A sequence of :class:`OctreeConv` and :obj:`OctreeGroupNorm`.
178
+
179
+ Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
180
+ '''
181
+
182
+ def __init__(self, in_channels: int, out_channels: int, group: int,
183
+ kernel_size: List[int] = [3], stride: int = 1,
184
+ nempty: bool = False):
185
+ super().__init__()
186
+ self.conv = OctreeConv(
187
+ in_channels, out_channels, kernel_size, stride, nempty)
188
+ self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
189
+
190
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
191
+ r''''''
192
+
193
+ out = self.conv(data, octree, depth)
194
+ out = self.gn(out, octree, depth)
195
+ return out
196
+
197
+
198
+ class OctreeConvGnRelu(torch.nn.Module):
199
+ r''' A sequence of :class:`OctreeConv`, :obj:`OctreeGroupNorm`, and :obj:`Relu`.
200
+
201
+ Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
202
+ '''
203
+
204
+ def __init__(self, in_channels: int, out_channels: int, group: int,
205
+ kernel_size: List[int] = [3], stride: int = 1,
206
+ nempty: bool = False):
207
+ super().__init__()
208
+ self.stride = stride
209
+ self.conv = OctreeConv(
210
+ in_channels, out_channels, kernel_size, stride, nempty)
211
+ self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
212
+ self.relu = torch.nn.ReLU(inplace=True)
213
+
214
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
215
+ r''''''
216
+
217
+ out = self.conv(data, octree, depth)
218
+ out = self.gn(out, octree, depth if self.stride == 1 else depth - 1)
219
+ out = self.relu(out)
220
+ return out
221
+
222
+
223
+ class OctreeDeconvGnRelu(torch.nn.Module):
224
+ r''' A sequence of :class:`OctreeDeconv`, :obj:`OctreeGroupNorm`, and :obj:`Relu`.
225
+
226
+ Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
227
+ '''
228
+
229
+ def __init__(self, in_channels: int, out_channels: int, group: int,
230
+ kernel_size: List[int] = [3], stride: int = 1,
231
+ nempty: bool = False):
232
+ super().__init__()
233
+ self.stride = stride
234
+ self.deconv = OctreeDeconv(
235
+ in_channels, out_channels, kernel_size, stride, nempty)
236
+ self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
237
+ self.relu = torch.nn.ReLU(inplace=True)
238
+
239
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
240
+ r''''''
241
+
242
+ out = self.deconv(data, octree, depth)
243
+ out = self.gn(out, octree, depth if self.stride == 1 else depth + 1)
244
+ out = self.relu(out)
245
+ return out
246
+
247
+
248
+ class Conv1x1Gn(torch.nn.Module):
249
+ r''' A sequence of :class:`Conv1x1`, :class:`OctreeGroupNorm`.
250
+ '''
251
+
252
+ def __init__(self, in_channels: int, out_channels: int, group: int,
253
+ nempty: bool = False):
254
+ super().__init__()
255
+ self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
256
+ self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
257
+
258
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
259
+ r''''''
260
+
261
+ out = self.conv(data)
262
+ out = self.gn(out, octree, depth)
263
+ return out
264
+
265
+
266
+ class Conv1x1GnRelu(torch.nn.Module):
267
+ r''' A sequence of :class:`Conv1x1`, :class:`OctreeGroupNorm` and :class:`Relu`.
268
+ '''
269
+
270
+ def __init__(self, in_channels: int, out_channels: int, group: int,
271
+ nempty: bool = False):
272
+ super().__init__()
273
+ self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
274
+ self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
275
+ self.relu = torch.nn.ReLU(inplace=True)
276
+
277
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
278
+ r''''''
279
+
280
+ out = self.conv(data)
281
+ out = self.gn(out, octree, depth)
282
+ out = self.relu(out)
283
+ return out
284
+
285
+
286
+ class InputFeature(torch.nn.Module):
287
+ r''' Returns the initial input feature stored in octree.
288
+
289
+ Refer to :func:`ocnn.octree.Octree.get_input_feature` for details.
290
+ '''
291
+
292
+ def __init__(self, feature: str = 'NDF', nempty: bool = False):
293
+ super().__init__()
294
+ self.nempty = nempty
295
+ self.feature = feature.upper()
296
+
297
+ def forward(self, octree: Octree):
298
+ r''''''
299
+ return octree.get_input_feature(self.feature, self.nempty)
300
+
301
+ def extra_repr(self) -> str:
302
+ r''''''
303
+ return 'feature={}, nempty={}'.format(self.feature, self.nempty)