learning3d-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. learning3d/__init__.py +2 -0
  2. learning3d/data_utils/__init__.py +4 -0
  3. learning3d/data_utils/dataloaders.py +454 -0
  4. learning3d/data_utils/user_data.py +119 -0
  5. learning3d/examples/test_dcp.py +139 -0
  6. learning3d/examples/test_deepgmr.py +144 -0
  7. learning3d/examples/test_flownet.py +113 -0
  8. learning3d/examples/test_masknet.py +159 -0
  9. learning3d/examples/test_masknet2.py +162 -0
  10. learning3d/examples/test_pcn.py +118 -0
  11. learning3d/examples/test_pcrnet.py +120 -0
  12. learning3d/examples/test_pnlk.py +121 -0
  13. learning3d/examples/test_pointconv.py +126 -0
  14. learning3d/examples/test_pointnet.py +121 -0
  15. learning3d/examples/test_prnet.py +126 -0
  16. learning3d/examples/test_rpmnet.py +120 -0
  17. learning3d/examples/train_PointNetLK.py +240 -0
  18. learning3d/examples/train_dcp.py +249 -0
  19. learning3d/examples/train_deepgmr.py +244 -0
  20. learning3d/examples/train_flownet.py +259 -0
  21. learning3d/examples/train_masknet.py +239 -0
  22. learning3d/examples/train_pcn.py +216 -0
  23. learning3d/examples/train_pcrnet.py +228 -0
  24. learning3d/examples/train_pointconv.py +245 -0
  25. learning3d/examples/train_pointnet.py +244 -0
  26. learning3d/examples/train_prnet.py +229 -0
  27. learning3d/examples/train_rpmnet.py +228 -0
  28. learning3d/losses/__init__.py +12 -0
  29. learning3d/losses/chamfer_distance.py +51 -0
  30. learning3d/losses/classification.py +14 -0
  31. learning3d/losses/correspondence_loss.py +10 -0
  32. learning3d/losses/cuda/chamfer_distance/__init__.py +1 -0
  33. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +185 -0
  34. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +209 -0
  35. learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +66 -0
  36. learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +41 -0
  37. learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +347 -0
  38. learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +18 -0
  39. learning3d/losses/cuda/emd_torch/pkg/include/emd.h +54 -0
  40. learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +1 -0
  41. learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +40 -0
  42. learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +70 -0
  43. learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +1 -0
  44. learning3d/losses/cuda/emd_torch/setup.py +29 -0
  45. learning3d/losses/emd.py +16 -0
  46. learning3d/losses/frobenius_norm.py +21 -0
  47. learning3d/losses/rmse_features.py +16 -0
  48. learning3d/models/__init__.py +23 -0
  49. learning3d/models/classifier.py +41 -0
  50. learning3d/models/dcp.py +92 -0
  51. learning3d/models/deepgmr.py +165 -0
  52. learning3d/models/dgcnn.py +92 -0
  53. learning3d/models/flownet3d.py +446 -0
  54. learning3d/models/masknet.py +84 -0
  55. learning3d/models/masknet2.py +264 -0
  56. learning3d/models/pcn.py +164 -0
  57. learning3d/models/pcrnet.py +74 -0
  58. learning3d/models/pointconv.py +108 -0
  59. learning3d/models/pointnet.py +108 -0
  60. learning3d/models/pointnetlk.py +173 -0
  61. learning3d/models/pooling.py +15 -0
  62. learning3d/models/ppfnet.py +102 -0
  63. learning3d/models/prnet.py +431 -0
  64. learning3d/models/rpmnet.py +359 -0
  65. learning3d/models/segmentation.py +38 -0
  66. learning3d/ops/__init__.py +0 -0
  67. learning3d/ops/data_utils.py +45 -0
  68. learning3d/ops/invmat.py +134 -0
  69. learning3d/ops/quaternion.py +218 -0
  70. learning3d/ops/se3.py +157 -0
  71. learning3d/ops/sinc.py +229 -0
  72. learning3d/ops/so3.py +213 -0
  73. learning3d/ops/transform_functions.py +342 -0
  74. learning3d/utils/__init__.py +9 -0
  75. learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so +0 -0
  76. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query.o +0 -0
  77. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query_gpu.o +0 -0
  78. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points.o +0 -0
  79. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points_gpu.o +0 -0
  80. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate.o +0 -0
  81. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate_gpu.o +0 -0
  82. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/pointnet2_api.o +0 -0
  83. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling.o +0 -0
  84. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling_gpu.o +0 -0
  85. learning3d/utils/lib/dist/pointnet2-0.0.0-py3.5-linux-x86_64.egg +0 -0
  86. learning3d/utils/lib/pointnet2.egg-info/SOURCES.txt +14 -0
  87. learning3d/utils/lib/pointnet2.egg-info/dependency_links.txt +1 -0
  88. learning3d/utils/lib/pointnet2.egg-info/top_level.txt +1 -0
  89. learning3d/utils/lib/pointnet2_modules.py +160 -0
  90. learning3d/utils/lib/pointnet2_utils.py +318 -0
  91. learning3d/utils/lib/pytorch_utils.py +236 -0
  92. learning3d/utils/lib/setup.py +23 -0
  93. learning3d/utils/lib/src/ball_query.cpp +25 -0
  94. learning3d/utils/lib/src/ball_query_gpu.cu +67 -0
  95. learning3d/utils/lib/src/ball_query_gpu.h +15 -0
  96. learning3d/utils/lib/src/cuda_utils.h +15 -0
  97. learning3d/utils/lib/src/group_points.cpp +36 -0
  98. learning3d/utils/lib/src/group_points_gpu.cu +86 -0
  99. learning3d/utils/lib/src/group_points_gpu.h +22 -0
  100. learning3d/utils/lib/src/interpolate.cpp +65 -0
  101. learning3d/utils/lib/src/interpolate_gpu.cu +233 -0
  102. learning3d/utils/lib/src/interpolate_gpu.h +36 -0
  103. learning3d/utils/lib/src/pointnet2_api.cpp +25 -0
  104. learning3d/utils/lib/src/sampling.cpp +46 -0
  105. learning3d/utils/lib/src/sampling_gpu.cu +253 -0
  106. learning3d/utils/lib/src/sampling_gpu.h +29 -0
  107. learning3d/utils/pointconv_util.py +382 -0
  108. learning3d/utils/ppfnet_util.py +244 -0
  109. learning3d/utils/svd.py +59 -0
  110. learning3d/utils/transformer.py +243 -0
  111. learning3d-0.0.1.dist-info/LICENSE +21 -0
  112. learning3d-0.0.1.dist-info/METADATA +271 -0
  113. learning3d-0.0.1.dist-info/RECORD +115 -0
  114. learning3d-0.0.1.dist-info/WHEEL +5 -0
  115. learning3d-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,228 @@
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import logging
5
+ import numpy
6
+ import numpy as np
7
+ import torch
8
+ import torch.utils.data
9
+ import torchvision
10
+ from torch.utils.data import DataLoader
11
+ from tensorboardX import SummaryWriter
12
+ from tqdm import tqdm
13
+
14
+ # Only if the files are in example folder.
15
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
16
+ if BASE_DIR[-8:] == 'examples':
17
+ sys.path.append(os.path.join(BASE_DIR, os.pardir))
18
+ os.chdir(os.path.join(BASE_DIR, os.pardir))
19
+
20
+ from learning3d.models import RPMNet, PPFNet
21
+ from learning3d.losses import FrobeniusNormLoss, RMSEFeaturesLoss
22
+ from learning3d.data_utils import RegistrationData, ModelNet40Data
23
+
24
+ def _init_(args):
25
+ if not os.path.exists('checkpoints'):
26
+ os.makedirs('checkpoints')
27
+ if not os.path.exists('checkpoints/' + args.exp_name):
28
+ os.makedirs('checkpoints/' + args.exp_name)
29
+ if not os.path.exists('checkpoints/' + args.exp_name + '/' + 'models'):
30
+ os.makedirs('checkpoints/' + args.exp_name + '/' + 'models')
31
+ os.system('cp main.py checkpoints' + '/' + args.exp_name + '/' + 'main.py.backup')
32
+ os.system('cp model.py checkpoints' + '/' + args.exp_name + '/' + 'model.py.backup')
33
+
34
+
35
class IOStream:
    """Minimal tee logger: every message goes to stdout and is appended to a file."""

    def __init__(self, path):
        # Append mode, so re-running an experiment extends the existing log.
        self.f = open(path, 'a')

    def cprint(self, text):
        """Echo ``text`` to stdout, then append it plus a newline to the log and flush."""
        print(text)
        log_file = self.f
        log_file.write(text + '\n')
        log_file.flush()

    def close(self):
        """Close the underlying log file."""
        self.f.close()
+
47
def test_one_epoch(device, model, test_loader):
    """Evaluate ``model`` over ``test_loader`` and return the mean per-batch loss.

    Each batch is unpacked as ``(template, source, igt)``. The loss per batch is
    ``FrobeniusNormLoss()(output['est_T'], igt) + RMSEFeaturesLoss()(output['r'])``,
    where ``output = model(template, source)``.

    :param device: device the batch tensors are moved to.
    :param model: registration network; switched to eval mode here.
    :param test_loader: iterable of ``(template, source, igt)`` batches.
    :return: mean of the per-batch losses, or ``nan`` if the loader is empty.
    """
    model.eval()
    frobenius_loss = FrobeniusNormLoss()
    features_loss = RMSEFeaturesLoss()
    test_loss = 0.0
    count = 0
    # No autograd graph is needed for evaluation; skipping it saves memory/time.
    # NOTE(review): assumes the model's forward does not itself call
    # torch.autograd (e.g. autograd.grad); confirm if the model changes.
    with torch.no_grad():
        for template, source, igt in tqdm(test_loader):
            template = template.to(device)
            source = source.to(device)
            igt = igt.to(device)

            output = model(template, source)
            loss_val = frobenius_loss(output['est_T'], igt) + features_loss(output['r'])

            test_loss += loss_val.item()
            count += 1

    # An empty loader previously raised ZeroDivisionError.
    if count == 0:
        return float('nan')
    return test_loss / count
67
+
68
def test(args, model, test_loader, textio):
    """Run one evaluation pass and log the validation loss via ``textio``.

    ``test_one_epoch`` returns a single float (the mean loss). The previous code
    unpacked two values (loss and accuracy), which raised TypeError whenever
    ``--eval`` was used. No accuracy metric is computed for registration.
    """
    test_loss = test_one_epoch(args.device, model, test_loader)
    textio.cprint('Validation Loss: %f' % test_loss)
71
+
72
def train_one_epoch(device, model, train_loader, optimizer):
    """Run one optimisation pass over ``train_loader`` and return the mean batch loss.

    Each batch is unpacked as ``(template, source, igt)``. The minimised loss is
    ``FrobeniusNormLoss()(output['est_T'], igt) + RMSEFeaturesLoss()(output['r'])``,
    where ``output = model(template, source)``.

    :param device: device the batch tensors are moved to.
    :param model: registration network; switched to train mode here.
    :param train_loader: iterable of ``(template, source, igt)`` batches.
    :param optimizer: optimizer stepped once per batch.
    :return: mean of the per-batch losses, or ``nan`` if the loader yields no
        batches (e.g. ``drop_last=True`` with fewer samples than ``batch_size``).
    """
    model.train()
    frobenius_loss = FrobeniusNormLoss()
    features_loss = RMSEFeaturesLoss()
    train_loss = 0.0
    count = 0
    for template, source, igt in tqdm(train_loader):
        template = template.to(device)
        source = source.to(device)
        igt = igt.to(device)

        output = model(template, source)
        loss_val = frobenius_loss(output['est_T'], igt) + features_loss(output['r'])

        # forward + backward + optimize
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        train_loss += loss_val.item()
        count += 1

    # An empty loader previously raised ZeroDivisionError.
    if count == 0:
        return float('nan')
    return train_loss / count
98
+
99
def train(args, model, train_loader, test_loader, boardio, textio, checkpoint):
    """Train ``model`` for epochs ``args.start_epoch .. args.epochs - 1`` with checkpointing.

    After each epoch the model is evaluated on ``test_loader``. Checkpoints are
    written to ``checkpoints/<exp_name>/models/``:

    * the latest state goes to ``model_snap.t7``, ``model.t7`` and ``ptnet_model.t7``;
    * whenever the test loss improves on the best so far, the same state also
      goes to the ``best_*`` counterparts.

    Losses are logged to ``boardio`` (TensorBoard) and ``textio``.

    :param checkpoint: ``None``, or a dict previously saved as a ``*_snap.t7``
        file with keys ``'epoch'``, ``'model'``, ``'min_loss'`` and ``'optimizer'``.
    """
    learnable_params = filter(lambda p: p.requires_grad, model.parameters())
    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(learnable_params)
    else:
        optimizer = torch.optim.SGD(learnable_params, lr=0.1)

    best_test_loss = np.inf
    if checkpoint is not None:
        # Resume: restore the best loss seen so far. Previously it was read and
        # then discarded, so a worse post-resume epoch overwrote the best model.
        best_test_loss = checkpoint['min_loss']
        optimizer.load_state_dict(checkpoint['optimizer'])

    models_dir = 'checkpoints/%s/models' % (args.exp_name)
    for epoch in range(args.start_epoch, args.epochs):
        train_loss = train_one_epoch(args.device, model, train_loader, optimizer)
        test_loss = test_one_epoch(args.device, model, test_loader)

        is_best = test_loss < best_test_loss
        if is_best:
            best_test_loss = test_loss

        # Build the snapshot every epoch. Previously it was only built on
        # improvement, so 'model_snap.t7' carried a stale epoch number, and a
        # non-improving first epoch raised NameError.
        snap = {'epoch': epoch + 1,
                'model': model.state_dict(),
                'min_loss': best_test_loss,
                'optimizer': optimizer.state_dict()}

        if is_best:
            torch.save(snap, models_dir + '/best_model_snap.t7')
            torch.save(model.state_dict(), models_dir + '/best_model.t7')
            torch.save(model.feature_model.state_dict(), models_dir + '/best_ptnet_model.t7')

        torch.save(snap, models_dir + '/model_snap.t7')
        torch.save(model.state_dict(), models_dir + '/model.t7')
        torch.save(model.feature_model.state_dict(), models_dir + '/ptnet_model.t7')

        boardio.add_scalar('Train Loss', train_loss, epoch+1)
        boardio.add_scalar('Test Loss', test_loss, epoch+1)
        boardio.add_scalar('Best Test Loss', best_test_loss, epoch+1)

        textio.cprint('EPOCH:: %d, Traininig Loss: %f, Testing Loss: %f, Best Loss: %f'%(epoch+1, train_loss, test_loss, best_test_loss))
135
+
136
def options():
    """Build the command-line parser for RPMNet training and parse ``sys.argv``.

    :return: ``argparse.Namespace`` holding the options defined below.
    """
    def _str2bool(value):
        # argparse's ``type=bool`` turns any non-empty string (including
        # 'False') into True, so parse the usual spellings explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)

    parser = argparse.ArgumentParser(description='Point Cloud Registration')
    parser.add_argument('--exp_name', type=str, default='exp_rpmnet', metavar='N',
                        help='Name of the experiment')
    parser.add_argument('--dataset_path', type=str, default='ModelNet40',
                        metavar='PATH', help='path to the input dataset')  # like '/path/to/ModelNet40'
    # Accepts '--eval', '--eval True' and '--eval False' (backward compatible
    # with the old value-taking form, but 'False' now really means False).
    parser.add_argument('--eval', type=_str2bool, nargs='?', const=True, default=False,
                        help='Train or Evaluate the network.')

    # settings for input data
    parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
                        metavar='DATASET', help='dataset type (default: modelnet)')
    parser.add_argument('--num_points', default=1024, type=int,
                        metavar='N', help='points in point-cloud (default: 1024)')

    # settings for PointNet
    parser.add_argument('--fine_tune_pointnet', default='tune', type=str, choices=['fixed', 'tune'],
                        help='train pointnet (default: tune)')
    parser.add_argument('--transfer_ptnet_weights', default='./checkpoints/exp_classifier/models/best_ptnet_model.t7', type=str,
                        metavar='PATH', help='path to pointnet features file')
    parser.add_argument('--emb_dims', default=1024, type=int,
                        metavar='K', help='dim. of the feature vector (default: 1024)')
    parser.add_argument('--symfn', default='max', choices=['max', 'avg'],
                        help='symmetric function (default: max)')

    # settings for on training
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('-j', '--workers', default=4, type=int,
                        metavar='N', help='number of data loading workers (default: 4)')
    # Help text previously claimed a default of 32.
    parser.add_argument('-b', '--batch_size', default=10, type=int,
                        metavar='N', help='mini-batch size (default: 10)')
    parser.add_argument('--epochs', default=200, type=int,
                        metavar='N', help='number of total epochs to run')
    parser.add_argument('--start_epoch', default=0, type=int,
                        metavar='N', help='manual epoch number (useful on restarts)')
    parser.add_argument('--optimizer', default='Adam', choices=['Adam', 'SGD'],
                        metavar='METHOD', help='name of an optimizer (default: Adam)')
    parser.add_argument('--resume', default='', type=str,
                        metavar='PATH', help='path to latest checkpoint (default: null (no-use))')
    parser.add_argument('--pretrained', default='', type=str,
                        metavar='PATH', help='path to pretrained model file (default: null (no-use))')
    parser.add_argument('--device', default='cuda:0', type=str,
                        metavar='DEVICE', help='use CUDA if available')

    args = parser.parse_args()
    return args
181
+
182
def main():
    """Script entry point: seed RNGs, build data and model, then train or evaluate.

    Outputs are written under ``checkpoints/<exp_name>/``. The text log and the
    TensorBoard writer are closed on exit, including when an error is raised.
    """
    args = options()

    torch.backends.cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    boardio = SummaryWriter(log_dir='checkpoints/' + args.exp_name)
    _init_(args)

    textio = IOStream('checkpoints/' + args.exp_name + '/run.log')
    try:
        textio.cprint(str(args))

        trainset = RegistrationData('RPMNet', ModelNet40Data(train=True, num_points=args.num_points, use_normals=True), partial_source=True, partial_template=True)
        testset = RegistrationData('RPMNet', ModelNet40Data(train=False, num_points=args.num_points, use_normals=True), partial_source=True, partial_template=True)
        train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.workers)
        test_loader = DataLoader(testset, batch_size=8, shuffle=False, drop_last=False, num_workers=args.workers)

        if not torch.cuda.is_available():
            args.device = 'cpu'
        args.device = torch.device(args.device)

        # Create RPMNet Model.
        model = RPMNet(feature_model=PPFNet())
        model = model.to(args.device)

        checkpoint = None
        if args.resume:
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            if not os.path.isfile(args.resume):
                raise FileNotFoundError('Checkpoint not found: %s' % args.resume)
            # Map to CPU like --pretrained does, so GPU-saved checkpoints also load
            # on CPU-only machines; load_state_dict copies into the model's device.
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])

        if args.pretrained:
            if not os.path.isfile(args.pretrained):
                raise FileNotFoundError('Pretrained model not found: %s' % args.pretrained)
            model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
        model.to(args.device)

        if args.eval:
            test(args, model, test_loader, textio)
        else:
            train(args, model, train_loader, test_loader, boardio, textio, checkpoint)
    finally:
        textio.close()
        boardio.close()


if __name__ == '__main__':
    main()
@@ -0,0 +1,12 @@
1
+ from .rmse_features import RMSEFeaturesLoss
2
+ from .frobenius_norm import FrobeniusNormLoss
3
+ from .classification import ClassificationLoss
4
+ from .correspondence_loss import CorrespondenceLoss
5
# Optional losses: importing these may fail on systems where their compiled
# extensions cannot be built or loaded. In that case a message is printed and
# the rest of the package stays usable. `except Exception` (instead of a bare
# `except:`) still lets KeyboardInterrupt/SystemExit propagate.
try:
    from .emd import EMDLoss
except Exception:
    print("Sorry EMD loss is not compatible with your system!")
try:
    from .chamfer_distance import ChamferDistanceLoss
except Exception:
    print("Sorry ChamferDistance loss is not compatible with your system!")
@@ -0,0 +1,51 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
def pairwise_distances(a: torch.Tensor, b: torch.Tensor, p=2):
    """
    Compute per-batch pairwise sums of p-th powers of coordinate differences.

    For ``a`` of size [m, n_a, d] and ``b`` of size [m, n_b, d], the result has size
    [m, n_a, n_b] and entry [k, i, j] equals ``sum_t |a[k, i, t] - b[k, j, t]| ** p``.
    No p-th root is taken, so for the default ``p=2`` this is the *squared*
    Euclidean distance.

    :param a: A tensor containing m batches of n_a points of dimension d, i.e. of size [m, n_a, d]
    :param b: A tensor containing m batches of n_b points of dimension d, i.e. of size [m, n_b, d]
    :param p: exponent applied to each absolute coordinate difference
    :return: A [m, n_a, n_b] tensor of pairwise (un-rooted) distances.
    :raises ValueError: if ``a`` or ``b`` is not 3-dimensional.
    """
    if a.dim() != 3:
        raise ValueError("Invalid shape for a. Must be [m, n, d] but got %s" % (tuple(a.shape),))
    if b.dim() != 3:
        # The message previously (wrongly) referred to 'a'.
        raise ValueError("Invalid shape for b. Must be [m, n, d] but got %s" % (tuple(b.shape),))
    # [m, n_a, 1, d] - [m, 1, n_b, d] broadcasts to [m, n_a, n_b, d].
    return (a.unsqueeze(2) - b.unsqueeze(1)).abs().pow(p).sum(3)
20
+
21
def chamfer(a, b):
    """
    Symmetric Chamfer distance between two batched point sets, reduced to a scalar.

    :param a: A m-sized minibatch of point sets in R^d, i.e. shape [m, n_a, d]
    :param b: A m-sized minibatch of point sets in R^d, i.e. shape [m, n_b, d]
    :return: a 0-dim tensor equal to half the sum of two terms:
        * the mean, over all batches and all points of b, of the Euclidean
          distance to the nearest point of a;
        * the mean, over all batches and all points of a, of the Euclidean
          distance to the nearest point of b.
    """
    sq_dists = pairwise_distances(a, b)  # [m, n_a, n_b], squared distances for p=2
    b_to_a = sq_dists.min(dim=1).values.sqrt().mean()  # nearest a-point for each b-point
    a_to_b = sq_dists.min(dim=2).values.sqrt().mean()  # nearest b-point for each a-point
    return (b_to_a + a_to_b) / 2.0
32
+
33
+
34
def chamfer_distance(template: torch.Tensor, source: torch.Tensor):
    """
    Symmetric Chamfer distance between ``template`` and ``source``.

    First tries the compiled ``ChamferDistance`` from ``.cuda.chamfer_distance``.
    Its two outputs are treated as per-point squared nearest-neighbour distances
    in each direction; they are square-rooted, averaged, and the two averages are
    halved. If that extension cannot be imported or fails on these inputs, the
    pure-PyTorch ``chamfer`` implementation is used instead.

    :return: a 0-dim tensor.
    """
    try:
        from .cuda.chamfer_distance import ChamferDistance
        cost_p0_p1, cost_p1_p0 = ChamferDistance()(template, source)
    except Exception:
        # Best-effort fallback. A bare `except:` here would also swallow
        # KeyboardInterrupt/SystemExit.
        return chamfer(template, source)
    cost_p0_p1 = torch.mean(torch.sqrt(cost_p0_p1))
    cost_p1_p0 = torch.mean(torch.sqrt(cost_p1_p0))
    return (cost_p0_p1 + cost_p1_p0) / 2.0
44
+
45
+
46
class ChamferDistanceLoss(nn.Module):
    """Parameter-free ``nn.Module`` wrapper around ``chamfer_distance``."""

    def __init__(self):
        super().__init__()

    def forward(self, template, source):
        """Return ``chamfer_distance(template, source)``."""
        loss = chamfer_distance(template, source)
        return loss
@@ -0,0 +1,14 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
def classification_loss(prediction: torch.Tensor, target: torch.Tensor):
    """Negative log-likelihood loss, averaged over the batch.

    :param prediction: per-class log-probabilities (as ``F.nll_loss`` expects,
        e.g. the output of ``log_softmax``).
    :param target: integer class indices.
    :return: scalar loss tensor.
    """
    loss = F.nll_loss(input=prediction, target=target)
    return loss
7
+
8
+
9
class ClassificationLoss(nn.Module):
    """Module form of the NLL classification loss (expects log-probabilities)."""

    def __init__(self):
        super().__init__()

    def forward(self, prediction, target):
        # Same computation as classification_loss(): batch-mean negative log-likelihood.
        return F.nll_loss(prediction, target)
@@ -0,0 +1,10 @@
1
+ import torch
2
+
3
class CorrespondenceLoss(torch.nn.Module):
    """Cross-entropy between predicted and ground-truth point correspondences.

    Shapes, as read in ``forward``:
      * template: (batch_size, C, num_template)
      * source: (batch_size, C, num_source)
      * corr_mat_pred: batch_size x num_source x num_template. These are the
        predicted correspondence scores, treated as unnormalised logits by
        ``cross_entropy``.
      * corr_mat: batch_size x num_template x num_source. This is the ground-truth
        correspondence matrix; for each source point, the template index with the
        largest entry is its target class.
    """

    def forward(self, template, source, corr_mat_pred, corr_mat):
        # corr_mat: batch_size x num_template x num_source (ground truth correspondence matrix)
        # corr_mat_pred: batch_size x num_source x num_template (predicted correspondence matrix)
        batch_size, _, num_points_template = template.shape
        _, _, num_points = source.shape
        # reshape (not view): the prediction may be non-contiguous (e.g. a
        # transposed tensor), which `.view` rejects with a RuntimeError.
        logits = corr_mat_pred.reshape(batch_size * num_points, num_points_template)
        target = torch.argmax(corr_mat.transpose(1, 2).reshape(-1, num_points_template), dim=1)
        return torch.nn.functional.cross_entropy(logits, target)
@@ -0,0 +1 @@
1
+ from .chamfer_distance import ChamferDistance
@@ -0,0 +1,185 @@
1
+ #include <torch/torch.h>
2
+
3
+ // CUDA forward declarations
4
+ int ChamferDistanceKernelLauncher(
5
+ const int b, const int n,
6
+ const float* xyz,
7
+ const int m,
8
+ const float* xyz2,
9
+ float* result,
10
+ int* result_i,
11
+ float* result2,
12
+ int* result2_i);
13
+
14
+ int ChamferDistanceGradKernelLauncher(
15
+ const int b, const int n,
16
+ const float* xyz1,
17
+ const int m,
18
+ const float* xyz2,
19
+ const float* grad_dist1,
20
+ const int* idx1,
21
+ const float* grad_dist2,
22
+ const int* idx2,
23
+ float* grad_xyz1,
24
+ float* grad_xyz2);
25
+
26
+
27
+ void chamfer_distance_forward_cuda(
28
+ const at::Tensor xyz1,
29
+ const at::Tensor xyz2,
30
+ const at::Tensor dist1,
31
+ const at::Tensor dist2,
32
+ const at::Tensor idx1,
33
+ const at::Tensor idx2)
34
+ {
35
+ ChamferDistanceKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data<float>(),
36
+ xyz2.size(1), xyz2.data<float>(),
37
+ dist1.data<float>(), idx1.data<int>(),
38
+ dist2.data<float>(), idx2.data<int>());
39
+ }
40
+
41
+ void chamfer_distance_backward_cuda(
42
+ const at::Tensor xyz1,
43
+ const at::Tensor xyz2,
44
+ at::Tensor gradxyz1,
45
+ at::Tensor gradxyz2,
46
+ at::Tensor graddist1,
47
+ at::Tensor graddist2,
48
+ at::Tensor idx1,
49
+ at::Tensor idx2)
50
+ {
51
+ ChamferDistanceGradKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data<float>(),
52
+ xyz2.size(1), xyz2.data<float>(),
53
+ graddist1.data<float>(), idx1.data<int>(),
54
+ graddist2.data<float>(), idx2.data<int>(),
55
+ gradxyz1.data<float>(), gradxyz2.data<float>());
56
+ }
57
+
58
+
59
// Brute-force nearest-neighbour search on the CPU.
// For each batch entry and each query point of xyz1 (laid out as [b, n, 3]),
// scans all m points of xyz2 (laid out as [b, m, 3]). It stores the smallest
// squared Euclidean distance in dist[batch * n + q] and the index of that point
// in idx[batch * n + q]. Ties keep the earliest index; with m == 0 both
// outputs are 0.
void nnsearch(
    const int b, const int n, const int m,
    const float* xyz1,
    const float* xyz2,
    float* dist,
    int* idx)
{
    for (int batch = 0; batch < b; ++batch) {
        const float* queries = xyz1 + batch * n * 3;
        const float* refs = xyz2 + batch * m * 3;
        for (int q = 0; q < n; ++q) {
            const float* p = queries + q * 3;
            double bestDist = 0;
            int bestIdx = 0;
            for (int r = 0; r < m; ++r) {
                const float* s = refs + r * 3;
                const float dx = s[0] - p[0];
                const float dy = s[1] - p[1];
                const float dz = s[2] - p[2];
                // The sum is evaluated in float and only then widened to double.
                const double d = dx * dx + dy * dy + dz * dz;
                if (r == 0 || d < bestDist) {
                    bestDist = d;
                    bestIdx = r;
                }
            }
            dist[batch * n + q] = bestDist;
            idx[batch * n + q] = bestIdx;
        }
    }
}
88
+
89
+
90
+ void chamfer_distance_forward(
91
+ const at::Tensor xyz1,
92
+ const at::Tensor xyz2,
93
+ const at::Tensor dist1,
94
+ const at::Tensor dist2,
95
+ const at::Tensor idx1,
96
+ const at::Tensor idx2)
97
+ {
98
+ const int batchsize = xyz1.size(0);
99
+ const int n = xyz1.size(1);
100
+ const int m = xyz2.size(1);
101
+
102
+ const float* xyz1_data = xyz1.data<float>();
103
+ const float* xyz2_data = xyz2.data<float>();
104
+ float* dist1_data = dist1.data<float>();
105
+ float* dist2_data = dist2.data<float>();
106
+ int* idx1_data = idx1.data<int>();
107
+ int* idx2_data = idx2.data<int>();
108
+
109
+ nnsearch(batchsize, n, m, xyz1_data, xyz2_data, dist1_data, idx1_data);
110
+ nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data);
111
+ }
112
+
113
+
114
+ void chamfer_distance_backward(
115
+ const at::Tensor xyz1,
116
+ const at::Tensor xyz2,
117
+ at::Tensor gradxyz1,
118
+ at::Tensor gradxyz2,
119
+ at::Tensor graddist1,
120
+ at::Tensor graddist2,
121
+ at::Tensor idx1,
122
+ at::Tensor idx2)
123
+ {
124
+ const int b = xyz1.size(0);
125
+ const int n = xyz1.size(1);
126
+ const int m = xyz2.size(1);
127
+
128
+ const float* xyz1_data = xyz1.data<float>();
129
+ const float* xyz2_data = xyz2.data<float>();
130
+ float* gradxyz1_data = gradxyz1.data<float>();
131
+ float* gradxyz2_data = gradxyz2.data<float>();
132
+ float* graddist1_data = graddist1.data<float>();
133
+ float* graddist2_data = graddist2.data<float>();
134
+ const int* idx1_data = idx1.data<int>();
135
+ const int* idx2_data = idx2.data<int>();
136
+
137
+ for (int i = 0; i < b*n*3; i++)
138
+ gradxyz1_data[i] = 0;
139
+ for (int i = 0; i < b*m*3; i++)
140
+ gradxyz2_data[i] = 0;
141
+ for (int i = 0;i < b; i++) {
142
+ for (int j = 0; j < n; j++) {
143
+ const float x1 = xyz1_data[(i*n+j)*3+0];
144
+ const float y1 = xyz1_data[(i*n+j)*3+1];
145
+ const float z1 = xyz1_data[(i*n+j)*3+2];
146
+ const int j2 = idx1_data[i*n+j];
147
+
148
+ const float x2 = xyz2_data[(i*m+j2)*3+0];
149
+ const float y2 = xyz2_data[(i*m+j2)*3+1];
150
+ const float z2 = xyz2_data[(i*m+j2)*3+2];
151
+ const float g = graddist1_data[i*n+j]*2;
152
+
153
+ gradxyz1_data[(i*n+j)*3+0] += g*(x1-x2);
154
+ gradxyz1_data[(i*n+j)*3+1] += g*(y1-y2);
155
+ gradxyz1_data[(i*n+j)*3+2] += g*(z1-z2);
156
+ gradxyz2_data[(i*m+j2)*3+0] -= (g*(x1-x2));
157
+ gradxyz2_data[(i*m+j2)*3+1] -= (g*(y1-y2));
158
+ gradxyz2_data[(i*m+j2)*3+2] -= (g*(z1-z2));
159
+ }
160
+ for (int j = 0; j < m; j++) {
161
+ const float x1 = xyz2_data[(i*m+j)*3+0];
162
+ const float y1 = xyz2_data[(i*m+j)*3+1];
163
+ const float z1 = xyz2_data[(i*m+j)*3+2];
164
+ const int j2 = idx2_data[i*m+j];
165
+ const float x2 = xyz1_data[(i*n+j2)*3+0];
166
+ const float y2 = xyz1_data[(i*n+j2)*3+1];
167
+ const float z2 = xyz1_data[(i*n+j2)*3+2];
168
+ const float g = graddist2_data[i*m+j]*2;
169
+ gradxyz2_data[(i*m+j)*3+0] += g*(x1-x2);
170
+ gradxyz2_data[(i*m+j)*3+1] += g*(y1-y2);
171
+ gradxyz2_data[(i*m+j)*3+2] += g*(z1-z2);
172
+ gradxyz1_data[(i*n+j2)*3+0] -= (g*(x1-x2));
173
+ gradxyz1_data[(i*n+j2)*3+1] -= (g*(y1-y2));
174
+ gradxyz1_data[(i*n+j2)*3+2] -= (g*(z1-z2));
175
+ }
176
+ }
177
+ }
178
+
179
+
180
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
181
+ m.def("forward", &chamfer_distance_forward, "ChamferDistance forward");
182
+ m.def("forward_cuda", &chamfer_distance_forward_cuda, "ChamferDistance forward (CUDA)");
183
+ m.def("backward", &chamfer_distance_backward, "ChamferDistance backward");
184
+ m.def("backward_cuda", &chamfer_distance_backward_cuda, "ChamferDistance backward (CUDA)");
185
+ }