learning3d 0.0.2-py3-none-any.whl → 0.0.4-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.
Files changed (44)
  1. learning3d/__init__.py +0 -2
  2. learning3d/data_utils/dataloaders.py +11 -14
  3. learning3d/models/__init__.py +1 -6
  4. learning3d/utils/__init__.py +1 -6
  5. {learning3d-0.0.2.dist-info → learning3d-0.0.4.dist-info}/METADATA +1 -1
  6. {learning3d-0.0.2.dist-info → learning3d-0.0.4.dist-info}/RECORD +9 -44
  7. learning3d/examples/test_flownet.py +0 -113
  8. learning3d/examples/train_flownet.py +0 -259
  9. learning3d/models/flownet3d.py +0 -446
  10. learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so +0 -0
  11. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query.o +0 -0
  12. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query_gpu.o +0 -0
  13. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points.o +0 -0
  14. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points_gpu.o +0 -0
  15. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate.o +0 -0
  16. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate_gpu.o +0 -0
  17. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/pointnet2_api.o +0 -0
  18. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling.o +0 -0
  19. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling_gpu.o +0 -0
  20. learning3d/utils/lib/dist/pointnet2-0.0.0-py3.5-linux-x86_64.egg +0 -0
  21. learning3d/utils/lib/pointnet2.egg-info/SOURCES.txt +0 -14
  22. learning3d/utils/lib/pointnet2.egg-info/dependency_links.txt +0 -1
  23. learning3d/utils/lib/pointnet2.egg-info/top_level.txt +0 -1
  24. learning3d/utils/lib/pointnet2_modules.py +0 -160
  25. learning3d/utils/lib/pointnet2_utils.py +0 -318
  26. learning3d/utils/lib/pytorch_utils.py +0 -236
  27. learning3d/utils/lib/setup.py +0 -23
  28. learning3d/utils/lib/src/ball_query.cpp +0 -25
  29. learning3d/utils/lib/src/ball_query_gpu.cu +0 -67
  30. learning3d/utils/lib/src/ball_query_gpu.h +0 -15
  31. learning3d/utils/lib/src/cuda_utils.h +0 -15
  32. learning3d/utils/lib/src/group_points.cpp +0 -36
  33. learning3d/utils/lib/src/group_points_gpu.cu +0 -86
  34. learning3d/utils/lib/src/group_points_gpu.h +0 -22
  35. learning3d/utils/lib/src/interpolate.cpp +0 -65
  36. learning3d/utils/lib/src/interpolate_gpu.cu +0 -233
  37. learning3d/utils/lib/src/interpolate_gpu.h +0 -36
  38. learning3d/utils/lib/src/pointnet2_api.cpp +0 -25
  39. learning3d/utils/lib/src/sampling.cpp +0 -46
  40. learning3d/utils/lib/src/sampling_gpu.cu +0 -253
  41. learning3d/utils/lib/src/sampling_gpu.h +0 -29
  42. {learning3d-0.0.2.dist-info → learning3d-0.0.4.dist-info}/LICENSE +0 -0
  43. {learning3d-0.0.2.dist-info → learning3d-0.0.4.dist-info}/WHEEL +0 -0
  44. {learning3d-0.0.2.dist-info → learning3d-0.0.4.dist-info}/top_level.txt +0 -0
learning3d/models/flownet3d.py
@@ -1,446 +0,0 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from time import time
- import numpy as np
-
- try:
-     from ..utils import pointnet2_utils as pointutils
- except ImportError:
-     print("Error in pointnet2_utils! Retry setup for pointnet2_utils.")
-
- def timeit(tag, t):
-     print("{}: {}s".format(tag, time() - t))
-     return time()
-
- def pc_normalize(pc):
-     l = pc.shape[0]
-     centroid = np.mean(pc, axis=0)
-     pc = pc - centroid
-     m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
-     pc = pc / m
-     return pc
-
- def square_distance(src, dst):
-     """
-     Calculate the squared Euclidean distance between every pair of points.
-     src^T * dst = xn * xm + yn * ym + zn * zm;
-     sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
-     sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
-     dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
-          = sum(src**2, dim=-1) + sum(dst**2, dim=-1) - 2 * src^T * dst
-     Input:
-         src: source points, [B, N, C]
-         dst: target points, [B, M, C]
-     Output:
-         dist: per-point square distance, [B, N, M]
-     """
-     B, N, _ = src.shape
-     _, M, _ = dst.shape
-     dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
-     dist += torch.sum(src ** 2, -1).view(B, N, 1)
-     dist += torch.sum(dst ** 2, -1).view(B, 1, M)
-     return dist
-
-
- def index_points(points, idx):
-     """
-     Input:
-         points: input points data, [B, N, C]
-         idx: sample index data, [B, S]
-     Return:
-         new_points: indexed points data, [B, S, C]
-     """
-     device = points.device
-     B = points.shape[0]
-     view_shape = list(idx.shape)
-     view_shape[1:] = [1] * (len(view_shape) - 1)
-     repeat_shape = list(idx.shape)
-     repeat_shape[0] = 1
-     batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
-     new_points = points[batch_indices, idx, :]
-     return new_points
-
-
- def farthest_point_sample(xyz, npoint):
-     """
-     Input:
-         xyz: pointcloud data, [B, N, C]
-         npoint: number of samples
-     Return:
-         centroids: sampled pointcloud index, [B, npoint]
-     """
-     device = xyz.device
-     B, N, C = xyz.shape
-     centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
-     distance = torch.ones(B, N).to(device) * 1e10
-     farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
-     batch_indices = torch.arange(B, dtype=torch.long).to(device)
-     for i in range(npoint):
-         centroids[:, i] = farthest
-         centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
-         dist = torch.sum((xyz - centroid) ** 2, -1)
-         mask = dist < distance
-         distance[mask] = dist[mask]
-         farthest = torch.max(distance, -1)[1]
-     return centroids
-
- def knn_point(k, pos1, pos2):
-     '''
-     Input:
-         k: int32, number of k in k-nn search
-         pos1: (batch_size, ndataset, c) float32 array, input points
-         pos2: (batch_size, npoint, c) float32 array, query points
-     Output:
-         val: (batch_size, npoint, k) float32 array, L2 distances
-         idx: (batch_size, npoint, k) int32 array, indices to input points
-     '''
-     B, N, C = pos1.shape
-     M = pos2.shape[1]
-     pos1 = pos1.view(B, 1, N, -1).repeat(1, M, 1, 1)
-     pos2 = pos2.view(B, M, 1, -1).repeat(1, 1, N, 1)
-     dist = torch.sum(-(pos1 - pos2) ** 2, -1)
-     val, idx = dist.topk(k=k, dim=-1)
-     return torch.sqrt(-val), idx
-
-
- def query_ball_point(radius, nsample, xyz, new_xyz):
-     """
-     Input:
-         radius: local region radius
-         nsample: max sample number in local region
-         xyz: all points, [B, N, C]
-         new_xyz: query points, [B, S, C]
-     Return:
-         group_idx: grouped points index, [B, S, nsample]
-     """
-     device = xyz.device
-     B, N, C = xyz.shape
-     _, S, _ = new_xyz.shape
-     group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
-     sqrdists = square_distance(new_xyz, xyz)
-     group_idx[sqrdists > radius ** 2] = N
-     mask = group_idx != N
-     cnt = mask.sum(dim=-1)
-     group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
-     group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
-     mask = group_idx == N
-     group_idx[mask] = group_first[mask]
-     return group_idx, cnt
-
-
- def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
-     """
-     Input:
-         npoint: number of centroids to sample
-         radius: local region radius
-         nsample: max sample number in local region
-         xyz: input points position data, [B, N, C]
-         points: input points data, [B, N, D]
-     Return:
-         new_xyz: sampled points position data, [B, npoint, C]
-         new_points: sampled points data, [B, npoint, nsample, C+D]
-     """
-     B, N, C = xyz.shape
-     S = npoint
-     fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]
-     new_xyz = index_points(xyz, fps_idx)
-     idx, _ = query_ball_point(radius, nsample, xyz, new_xyz)
-     grouped_xyz = index_points(xyz, idx)  # [B, npoint, nsample, C]
-     grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
-     if points is not None:
-         grouped_points = index_points(points, idx)
-         new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)  # [B, npoint, nsample, C+D]
-     else:
-         new_points = grouped_xyz_norm
-     if returnfps:
-         return new_xyz, new_points, grouped_xyz, fps_idx
-     else:
-         return new_xyz, new_points
-
-
- def sample_and_group_all(xyz, points):
-     """
-     Input:
-         xyz: input points position data, [B, N, C]
-         points: input points data, [B, N, D]
-     Return:
-         new_xyz: sampled points position data, [B, 1, C]
-         new_points: sampled points data, [B, 1, N, C+D]
-     """
-     device = xyz.device
-     B, N, C = xyz.shape
-     new_xyz = torch.zeros(B, 1, C).to(device)
-     grouped_xyz = xyz.view(B, 1, N, C)
-     if points is not None:
-         new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
-     else:
-         new_points = grouped_xyz
-     return new_xyz, new_points
-
- class PointNetSetAbstraction(nn.Module):
-     def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
-         super(PointNetSetAbstraction, self).__init__()
-         self.npoint = npoint
-         self.radius = radius
-         self.nsample = nsample
-         self.group_all = group_all
-         self.mlp_convs = nn.ModuleList()
-         self.mlp_bns = nn.ModuleList()
-         last_channel = in_channel + 3  # +3 for the xyz offsets concatenated during grouping
-         for out_channel in mlp:
-             self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1, bias=False))
-             self.mlp_bns.append(nn.BatchNorm2d(out_channel))
-             last_channel = out_channel
-
-         if group_all:
-             self.queryandgroup = pointutils.GroupAll()
-         else:
-             self.queryandgroup = pointutils.QueryAndGroup(radius, nsample)
-
-     def forward(self, xyz, points):
-         """
-         Input:
-             xyz: input points position data, [B, C, N]
-             points: input points data, [B, D, N]
-         Return:
-             new_xyz: sampled points position data, [B, C, npoint]
-             new_points: sampled points feature data, [B, mlp[-1], npoint]
-         """
-         device = xyz.device
-         B, C, N = xyz.shape
-         xyz_t = xyz.permute(0, 2, 1).contiguous()
-         # if points is not None:
-         #     points = points.permute(0, 2, 1).contiguous()
-
-         # sample centroids and group their neighborhood points
-         if not self.group_all:
-             fps_idx = pointutils.furthest_point_sample(xyz_t, self.npoint)  # [B, npoint]
-             new_xyz = pointutils.gather_operation(xyz, fps_idx)  # [B, C, npoint]
-         else:
-             new_xyz = xyz
-         new_points = self.queryandgroup(xyz_t, new_xyz.transpose(2, 1).contiguous(), points)  # [B, 3+D, npoint, nsample]
-
-         # new_xyz: sampled points position data, [B, C, npoint]
-         # new_points: sampled points data, [B, C+D, npoint, nsample]
-         for i, conv in enumerate(self.mlp_convs):
-             bn = self.mlp_bns[i]
-             new_points = F.relu(bn(conv(new_points)))
-
-         new_points = torch.max(new_points, -1)[0]
-         return new_xyz, new_points
-
- class FlowEmbedding(nn.Module):
-     def __init__(self, radius, nsample, in_channel, mlp, pooling='max', corr_func='concat', knn=True):
-         super(FlowEmbedding, self).__init__()
-         self.radius = radius
-         self.nsample = nsample
-         self.knn = knn
-         self.pooling = pooling
-         self.corr_func = corr_func
-         self.mlp_convs = nn.ModuleList()
-         self.mlp_bns = nn.ModuleList()
-         if corr_func == 'concat':
-             last_channel = in_channel * 2 + 3
-         for out_channel in mlp:
-             self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1, bias=False))
-             self.mlp_bns.append(nn.BatchNorm2d(out_channel))
-             last_channel = out_channel
-
-     def forward(self, pos1, pos2, feature1, feature2):
-         """
-         Input:
-             pos1: (batch_size, 3, npoint)
-             pos2: (batch_size, 3, npoint)
-             feature1: (batch_size, channel, npoint)
-             feature2: (batch_size, channel, npoint)
-         Output:
-             pos1: (batch_size, 3, npoint)
-             feat1_new: (batch_size, mlp[-1], npoint)
-         """
-         pos1_t = pos1.permute(0, 2, 1).contiguous()
-         pos2_t = pos2.permute(0, 2, 1).contiguous()
-         B, N, C = pos1_t.shape
-         if self.knn:
-             _, idx = pointutils.knn(self.nsample, pos1_t, pos2_t)
-         else:
-             # If the ball query finds fewer than nsample neighbors,
-             # fall back to the knn neighborhood instead.
-             idx, cnt = query_ball_point(self.radius, self.nsample, pos2_t, pos1_t)
-             # use knn to pick the nearest points
-             _, idx_knn = pointutils.knn(self.nsample, pos1_t, pos2_t)
-             cnt = cnt.view(B, -1, 1).repeat(1, 1, self.nsample)
-             idx = idx_knn[cnt > (self.nsample - 1)]
-
-         pos2_grouped = pointutils.grouping_operation(pos2, idx)  # [B, 3, N, S]
-         pos_diff = pos2_grouped - pos1.view(B, -1, N, 1)  # [B, 3, N, S]
-
-         feat2_grouped = pointutils.grouping_operation(feature2, idx)  # [B, C, N, S]
-         if self.corr_func == 'concat':
-             feat_diff = torch.cat([feat2_grouped, feature1.view(B, -1, N, 1).repeat(1, 1, 1, self.nsample)], dim=1)
-
-         feat1_new = torch.cat([pos_diff, feat_diff], dim=1)  # [B, 2*C+3, N, S]
-         for i, conv in enumerate(self.mlp_convs):
-             bn = self.mlp_bns[i]
-             feat1_new = F.relu(bn(conv(feat1_new)))
-
-         feat1_new = torch.max(feat1_new, -1)[0]  # [B, mlp[-1], npoint]
-         return pos1, feat1_new
-
- class PointNetSetUpConv(nn.Module):
-     def __init__(self, nsample, radius, f1_channel, f2_channel, mlp, mlp2, knn=True):
-         super(PointNetSetUpConv, self).__init__()
-         self.nsample = nsample
-         self.radius = radius
-         self.knn = knn
-         self.mlp1_convs = nn.ModuleList()
-         self.mlp2_convs = nn.ModuleList()
-         last_channel = f2_channel + 3
-         for out_channel in mlp:
-             self.mlp1_convs.append(nn.Sequential(nn.Conv2d(last_channel, out_channel, 1, bias=False),
-                                                  nn.BatchNorm2d(out_channel),
-                                                  nn.ReLU(inplace=False)))
-             last_channel = out_channel
-         if len(mlp) != 0:
-             last_channel = mlp[-1] + f1_channel
-         else:
-             last_channel = last_channel + f1_channel
-         for out_channel in mlp2:
-             self.mlp2_convs.append(nn.Sequential(nn.Conv1d(last_channel, out_channel, 1, bias=False),
-                                                  nn.BatchNorm1d(out_channel),
-                                                  nn.ReLU(inplace=False)))
-             last_channel = out_channel
-
-     def forward(self, pos1, pos2, feature1, feature2):
-         """
-         Feature propagation from pos2 (fewer points) to pos1 (more points).
-         Inputs:
-             pos1: (batch_size, 3, npoint1)
-             pos2: (batch_size, 3, npoint2)
-             feature1: (batch_size, channel1, npoint1) features for pos1 points (earlier layer, more points)
-             feature2: (batch_size, channel2, npoint2) features for pos2 points
-         Output:
-             feat_new: (batch_size, mlp2[-1], npoint1); if mlp2 is empty, the channel count is
-                 mlp[-1]+channel1 (or channel2+3+channel1 if mlp is also empty)
-         TODO: Add support for skip links. Study how delta(XYZ) plays a role in feature updating.
-         """
-         pos1_t = pos1.permute(0, 2, 1).contiguous()
-         pos2_t = pos2.permute(0, 2, 1).contiguous()
-         B, C, N = pos1.shape
-         if self.knn:
-             _, idx = pointutils.knn(self.nsample, pos1_t, pos2_t)
-         else:
-             idx, _ = query_ball_point(self.radius, self.nsample, pos2_t, pos1_t)
-
-         pos2_grouped = pointutils.grouping_operation(pos2, idx)
-         pos_diff = pos2_grouped - pos1.view(B, -1, N, 1)  # [B, 3, N1, S]
-
-         feat2_grouped = pointutils.grouping_operation(feature2, idx)
-         feat_new = torch.cat([feat2_grouped, pos_diff], dim=1)  # [B, C2+3, N1, S]
-         for conv in self.mlp1_convs:
-             feat_new = conv(feat_new)
-         # max pooling
-         feat_new = feat_new.max(-1)[0]  # [B, mlp[-1], N1]
-         # concatenate features from the earlier layer
-         if feature1 is not None:
-             feat_new = torch.cat([feat_new, feature1], dim=1)
-         # feat_new = feat_new.view(B, -1, N, 1)
-         for conv in self.mlp2_convs:
-             feat_new = conv(feat_new)
-
-         return feat_new
-
- class PointNetFeaturePropogation(nn.Module):
-     def __init__(self, in_channel, mlp):
-         super(PointNetFeaturePropogation, self).__init__()
-         self.mlp_convs = nn.ModuleList()
-         self.mlp_bns = nn.ModuleList()
-         last_channel = in_channel
-         for out_channel in mlp:
-             self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
-             self.mlp_bns.append(nn.BatchNorm1d(out_channel))
-             last_channel = out_channel
-
-     def forward(self, pos1, pos2, feature1, feature2):
-         """
-         Input:
-             pos1: input points position data, [B, C, N]
-             pos2: sampled input points position data, [B, C, S]
-             feature1: input points data, [B, D, N]
-             feature2: input points data, [B, D, S]
-         Return:
-             feat_new: upsampled points data, [B, D', N]
-         """
-         pos1_t = pos1.permute(0, 2, 1).contiguous()
-         pos2_t = pos2.permute(0, 2, 1).contiguous()
-         B, C, N = pos1.shape
-
-         # dists = square_distance(pos1, pos2)
-         # dists, idx = dists.sort(dim=-1)
-         # dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]
-         dists, idx = pointutils.three_nn(pos1_t, pos2_t)
-         dists[dists < 1e-10] = 1e-10
-         weight = 1.0 / dists
-         weight = weight / torch.sum(weight, -1, keepdim=True)  # [B, N, 3]
-         interpolated_feat = torch.sum(pointutils.grouping_operation(feature2, idx) * weight.view(B, 1, N, 3), dim=-1)  # [B, C, N]
-
-         if feature1 is not None:
-             feat_new = torch.cat([interpolated_feat, feature1], 1)
-         else:
-             feat_new = interpolated_feat
-
-         for i, conv in enumerate(self.mlp_convs):
-             bn = self.mlp_bns[i]
-             feat_new = F.relu(bn(conv(feat_new)))
-         return feat_new
-
-
- class FlowNet3D(nn.Module):
-     def __init__(self):
-         super(FlowNet3D, self).__init__()
-
-         self.sa1 = PointNetSetAbstraction(npoint=1024, radius=0.5, nsample=16, in_channel=3, mlp=[32, 32, 64], group_all=False)
-         self.sa2 = PointNetSetAbstraction(npoint=256, radius=1.0, nsample=16, in_channel=64, mlp=[64, 64, 128], group_all=False)
-         self.sa3 = PointNetSetAbstraction(npoint=64, radius=2.0, nsample=8, in_channel=128, mlp=[128, 128, 256], group_all=False)
-         self.sa4 = PointNetSetAbstraction(npoint=16, radius=4.0, nsample=8, in_channel=256, mlp=[256, 256, 512], group_all=False)
-
-         self.fe_layer = FlowEmbedding(radius=10.0, nsample=64, in_channel=128, mlp=[128, 128, 128], pooling='max', corr_func='concat')
-
-         self.su1 = PointNetSetUpConv(nsample=8, radius=2.4, f1_channel=256, f2_channel=512, mlp=[], mlp2=[256, 256])
-         self.su2 = PointNetSetUpConv(nsample=8, radius=1.2, f1_channel=128+128, f2_channel=256, mlp=[128, 128, 256], mlp2=[256])
-         self.su3 = PointNetSetUpConv(nsample=8, radius=0.6, f1_channel=64, f2_channel=256, mlp=[128, 128, 256], mlp2=[256])
-         self.fp = PointNetFeaturePropogation(in_channel=256+3, mlp=[256, 256])
-
-         self.conv1 = nn.Conv1d(256, 128, kernel_size=1, bias=False)
-         self.bn1 = nn.BatchNorm1d(128)
-         self.conv2 = nn.Conv1d(128, 3, kernel_size=1, bias=True)
-
-     def forward(self, pc1, pc2, feature1, feature2):
-         l1_pc1, l1_feature1 = self.sa1(pc1, feature1)
-         l2_pc1, l2_feature1 = self.sa2(l1_pc1, l1_feature1)
-
-         l1_pc2, l1_feature2 = self.sa1(pc2, feature2)
-         l2_pc2, l2_feature2 = self.sa2(l1_pc2, l1_feature2)
-
-         _, l2_feature1_new = self.fe_layer(l2_pc1, l2_pc2, l2_feature1, l2_feature2)
-
-         l3_pc1, l3_feature1 = self.sa3(l2_pc1, l2_feature1_new)
-         l4_pc1, l4_feature1 = self.sa4(l3_pc1, l3_feature1)
-
-         l3_fnew1 = self.su1(l3_pc1, l4_pc1, l3_feature1, l4_feature1)
-         l2_fnew1 = self.su2(l2_pc1, l3_pc1, torch.cat([l2_feature1, l2_feature1_new], dim=1), l3_fnew1)
-         l1_fnew1 = self.su3(l1_pc1, l2_pc1, l1_feature1, l2_fnew1)
-         l0_fnew1 = self.fp(pc1, l1_pc1, feature1, l1_fnew1)
-
-         x = F.relu(self.bn1(self.conv1(l0_fnew1)))
-         sf = self.conv2(x)
-         return sf
-
- if __name__ == '__main__':
-     import os
-     import torch
-     os.environ["CUDA_VISIBLE_DEVICES"] = '0'
-     input = torch.randn((8, 3, 2048))
-     label = torch.randn(8, 16)
-     model = FlowNet3D()
-     output = model(input, input, input, input)  # forward() takes two point clouds and two feature sets
-     print(output.size())
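
For reference, the deleted model consumed two point clouds plus per-point features and regressed a 3-channel scene-flow vector for every point of the first cloud. Below is a minimal, hypothetical usage sketch consistent with the deleted code above; it assumes the bundled pointnet2 CUDA extension (also removed in 0.0.4) is built and a GPU is available, and the import path follows the 0.0.2 package layout:

import torch
from learning3d.models import FlowNet3D  # export as of 0.0.2; removed in 0.0.4

model = FlowNet3D().cuda()
pc1 = torch.randn(8, 3, 2048).cuda()  # source cloud, [B, 3, N]
pc2 = torch.randn(8, 3, 2048).cuda()  # target cloud, [B, 3, N]
# sa1 was built with in_channel=3, so 3-channel features are expected;
# reusing the coordinates themselves as features is one simple choice.
flow = model(pc1, pc2, pc1, pc2)
print(flow.size())  # torch.Size([8, 3, 2048]): per-point scene flow

Note that the pointutils ops are CUDA kernels, so the model cannot run on CPU tensors; the original __main__ block above would have failed for that reason (and for its two-argument call) even in 0.0.2.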
learning3d/utils/lib/pointnet2.egg-info/SOURCES.txt
@@ -1,14 +0,0 @@
- setup.py
- pointnet2.egg-info/PKG-INFO
- pointnet2.egg-info/SOURCES.txt
- pointnet2.egg-info/dependency_links.txt
- pointnet2.egg-info/top_level.txt
- src/ball_query.cpp
- src/ball_query_gpu.cu
- src/group_points.cpp
- src/group_points_gpu.cu
- src/interpolate.cpp
- src/interpolate_gpu.cu
- src/pointnet2_api.cpp
- src/sampling.cpp
- src/sampling_gpu.cu
learning3d/utils/lib/pointnet2.egg-info/top_level.txt
@@ -1 +0,0 @@
- pointnet2_cuda
learning3d/utils/lib/pointnet2_modules.py
@@ -1,160 +0,0 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- from . import pointnet2_utils
- from . import pytorch_utils as pt_utils
- from typing import List, Tuple
-
-
- class _PointnetSAModuleBase(nn.Module):
-
-     def __init__(self):
-         super().__init__()
-         self.npoint = None
-         self.groupers = None
-         self.mlps = None
-         self.pool_method = 'max_pool'
-
-     def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> Tuple[torch.Tensor, torch.Tensor]:
-         """
-         :param xyz: (B, N, 3) tensor of the xyz coordinates of the features
-         :param features: (B, N, C) tensor of the descriptors of the features
-         :param new_xyz: optional pre-computed centroid coordinates
-         :return:
-             new_xyz: (B, npoint, 3) tensor of the new features' xyz
-             new_features: (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
-         """
-         new_features_list = []
-
-         xyz_flipped = xyz.transpose(1, 2).contiguous()
-         if new_xyz is None:
-             new_xyz = pointnet2_utils.gather_operation(
-                 xyz_flipped,
-                 pointnet2_utils.furthest_point_sample(xyz, self.npoint)
-             ).transpose(1, 2).contiguous() if self.npoint is not None else None
-
-         for i in range(len(self.groupers)):
-             new_features = self.groupers[i](xyz, new_xyz, features)  # (B, C, npoint, nsample)
-
-             new_features = self.mlps[i](new_features)  # (B, mlp[-1], npoint, nsample)
-             if self.pool_method == 'max_pool':
-                 new_features = F.max_pool2d(
-                     new_features, kernel_size=[1, new_features.size(3)]
-                 )  # (B, mlp[-1], npoint, 1)
-             elif self.pool_method == 'avg_pool':
-                 new_features = F.avg_pool2d(
-                     new_features, kernel_size=[1, new_features.size(3)]
-                 )  # (B, mlp[-1], npoint, 1)
-             else:
-                 raise NotImplementedError
-
-             new_features = new_features.squeeze(-1)  # (B, mlp[-1], npoint)
-             new_features_list.append(new_features)
-
-         return new_xyz, torch.cat(new_features_list, dim=1)
-
-
- class PointnetSAModuleMSG(_PointnetSAModuleBase):
-     """Pointnet set abstraction layer with multiscale grouping"""
-
-     def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True,
-                  use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
-         """
-         :param npoint: int
-         :param radii: list of float, list of radii to group with
-         :param nsamples: list of int, number of samples in each ball query
-         :param mlps: list of list of int, spec of the pointnet before the global pooling for each scale
-         :param bn: whether to use batchnorm
-         :param use_xyz: whether to concatenate xyz to the grouped features
-         :param pool_method: max_pool / avg_pool
-         :param instance_norm: whether to use instance_norm
-         """
-         super().__init__()
-
-         assert len(radii) == len(nsamples) == len(mlps)
-
-         self.npoint = npoint
-         self.groupers = nn.ModuleList()
-         self.mlps = nn.ModuleList()
-         for i in range(len(radii)):
-             radius = radii[i]
-             nsample = nsamples[i]
-             self.groupers.append(
-                 pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
-                 if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
-             )
-             mlp_spec = mlps[i]
-             if use_xyz:
-                 mlp_spec[0] += 3
-
-             self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm))
-         self.pool_method = pool_method
-
-
- class PointnetSAModule(PointnetSAModuleMSG):
-     """Pointnet set abstraction layer"""
-
-     def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None,
-                  bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
-         """
-         :param mlp: list of int, spec of the pointnet before the global max_pool
-         :param npoint: int, number of features
-         :param radius: float, radius of ball
-         :param nsample: int, number of samples in the ball query
-         :param bn: whether to use batchnorm
-         :param use_xyz: whether to concatenate xyz to the grouped features
-         :param pool_method: max_pool / avg_pool
-         :param instance_norm: whether to use instance_norm
-         """
-         super().__init__(
-             mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz,
-             pool_method=pool_method, instance_norm=instance_norm
-         )
-
-
- class PointnetFPModule(nn.Module):
-     r"""Propagates the features of one set to another"""
-
-     def __init__(self, *, mlp: List[int], bn: bool = True):
-         """
-         :param mlp: list of int
-         :param bn: whether to use batchnorm
-         """
-         super().__init__()
-         self.mlp = pt_utils.SharedMLP(mlp, bn=bn)
-
-     def forward(
-             self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor
-     ) -> torch.Tensor:
-         """
-         :param unknown: (B, n, 3) tensor of the xyz positions of the unknown features
-         :param known: (B, m, 3) tensor of the xyz positions of the known features
-         :param unknow_feats: (B, C1, n) tensor of the features to be propagated to
-         :param known_feats: (B, C2, m) tensor of features to be propagated
-         :return:
-             new_features: (B, mlp[-1], n) tensor of the features of the unknown features
-         """
-         if known is not None:
-             dist, idx = pointnet2_utils.three_nn(unknown, known)
-             dist_recip = 1.0 / (dist + 1e-8)
-             norm = torch.sum(dist_recip, dim=2, keepdim=True)
-             weight = dist_recip / norm
-
-             interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight)
-         else:
-             interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1))
-
-         if unknow_feats is not None:
-             new_features = torch.cat([interpolated_feats, unknow_feats], dim=1)  # (B, C2 + C1, n)
-         else:
-             new_features = interpolated_feats
-
-         new_features = new_features.unsqueeze(-1)
-         new_features = self.mlp(new_features)
-
-         return new_features.squeeze(-1)
-
-
- if __name__ == "__main__":
-     pass
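
As a companion reference, here is a small, hypothetical sketch of the removed PointnetSAModule. It again assumes the compiled pointnet2_cuda extension and a GPU; the import path is illustrative (in 0.0.2 the file lived at learning3d/utils/lib/pointnet2_modules.py), and the shapes follow the docstrings in the deleted code:

import torch
from learning3d.utils.lib.pointnet2_modules import PointnetSAModule  # hypothetical import path

# Single-scale set abstraction: sample 512 centroids, group up to 32
# neighbors within radius 0.2, then apply a shared MLP and max-pool.
# With use_xyz=True the module adds 3 to mlp[0] internally, so mlp[0]
# must equal the per-point feature channel count (here C = 3).
sa = PointnetSAModule(mlp=[3, 64, 128], npoint=512, radius=0.2, nsample=32).cuda()
xyz = torch.randn(4, 1024, 3).cuda()       # (B, N, 3) point coordinates
features = torch.randn(4, 3, 1024).cuda()  # (B, C, N) per-point descriptors
new_xyz, new_features = sa(xyz, features)
print(new_xyz.shape)       # torch.Size([4, 512, 3])
print(new_features.shape)  # torch.Size([4, 128, 512])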