ocnn 2.2.6__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ocnn/__init__.py CHANGED
@@ -1,24 +1,24 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- from . import octree
9
- from . import nn
10
- from . import modules
11
- from . import models
12
- from . import dataset
13
- from . import utils
14
-
15
- __version__ = '2.2.6'
16
-
17
- __all__ = [
18
- 'octree',
19
- 'nn',
20
- 'modules',
21
- 'models',
22
- 'dataset',
23
- 'utils'
24
- ]
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ from . import octree
9
+ from . import nn
10
+ from . import modules
11
+ from . import models
12
+ from . import dataset
13
+ from . import utils
14
+
15
+ __version__ = '2.2.7'
16
+
17
+ __all__ = [
18
+ 'octree',
19
+ 'nn',
20
+ 'modules',
21
+ 'models',
22
+ 'dataset',
23
+ 'utils'
24
+ ]
ocnn/dataset.py CHANGED
@@ -1,160 +1,160 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
-
10
- import ocnn
11
- from ocnn.octree import Octree, Points
12
-
13
-
14
- __all__ = ['Transform', 'CollateBatch']
15
- classes = __all__
16
-
17
-
18
- class Transform:
19
- r''' A boilerplate class which transforms an input data for :obj:`ocnn`.
20
- The input data is first converted to :class:`Points`, then randomly transformed
21
- (if enabled), and converted to an :class:`Octree`.
22
-
23
- Args:
24
- depth (int): The octree depth.
25
- full_depth (int): The octree layers with a depth smaller than
26
- :attr:`full_depth` are forced to be full.
27
- distort (bool): If true, performs the data augmentation.
28
- angle (list): A list of 3 float values to generate random rotation angles.
29
- interval (list): A list of 3 float values to represent the interval of
30
- rotation angles.
31
- scale (float): The maximum relative scale factor.
32
- uniform (bool): If true, performs uniform scaling.
33
- jitter (float): The maximum jitter value.
34
- orient_normal (str): Orient point normals along the specified axis, which is
35
- useful when normals are not oriented.
36
- '''
37
-
38
- def __init__(self, depth: int, full_depth: int, distort: bool, angle: list,
39
- interval: list, scale: float, uniform: bool, jitter: float,
40
- flip: list, orient_normal: str = '', **kwargs):
41
- super().__init__()
42
-
43
- # for octree building
44
- self.depth = depth
45
- self.full_depth = full_depth
46
-
47
- # for data augmentation
48
- self.distort = distort
49
- self.angle = angle
50
- self.interval = interval
51
- self.scale = scale
52
- self.uniform = uniform
53
- self.jitter = jitter
54
- self.flip = flip
55
-
56
- # for other transformations
57
- self.orient_normal = orient_normal
58
-
59
- def __call__(self, sample: dict, idx: int):
60
- r''''''
61
-
62
- output = self.preprocess(sample, idx)
63
- output = self.transform(output, idx)
64
- output['octree'] = self.points2octree(output['points'])
65
- return output
66
-
67
- def preprocess(self, sample: dict, idx: int):
68
- r''' Transforms :attr:`sample` to :class:`Points` and performs some specific
69
- transformations, like normalization.
70
- '''
71
-
72
- xyz = torch.from_numpy(sample.pop('points'))
73
- normals = torch.from_numpy(sample.pop('normals'))
74
- sample['points'] = Points(xyz, normals)
75
- return sample
76
-
77
- def transform(self, sample: dict, idx: int):
78
- r''' Applies the general transformations provided by :obj:`ocnn`.
79
- '''
80
-
81
- # The augmentations including rotation, scaling, and jittering.
82
- points = sample['points']
83
- if self.distort:
84
- rng_angle, rng_scale, rng_jitter, rnd_flip = self.rnd_parameters()
85
- points.flip(rnd_flip)
86
- points.rotate(rng_angle)
87
- points.translate(rng_jitter)
88
- points.scale(rng_scale)
89
-
90
- if self.orient_normal:
91
- points.orient_normal(self.orient_normal)
92
-
93
- # !!! NOTE: Clip the point cloud to [-1, 1] before building the octree
94
- inbox_mask = points.clip(min=-1, max=1)
95
- sample.update({'points': points, 'inbox_mask': inbox_mask})
96
- return sample
97
-
98
- def points2octree(self, points: Points):
99
- r''' Converts the input :attr:`points` to an octree.
100
- '''
101
-
102
- octree = Octree(self.depth, self.full_depth)
103
- octree.build_octree(points)
104
- return octree
105
-
106
- def rnd_parameters(self):
107
- r''' Generates random parameters for data augmentation.
108
- '''
109
-
110
- rnd_angle = [None] * 3
111
- for i in range(3):
112
- rot_num = self.angle[i] // self.interval[i]
113
- rnd = torch.randint(low=-rot_num, high=rot_num+1, size=(1,))
114
- rnd_angle[i] = rnd * self.interval[i] * (3.14159265 / 180.0)
115
- rnd_angle = torch.cat(rnd_angle)
116
-
117
- rnd_scale = torch.rand(3) * (2 * self.scale) - self.scale + 1.0
118
- if self.uniform:
119
- rnd_scale[1] = rnd_scale[0]
120
- rnd_scale[2] = rnd_scale[0]
121
-
122
- rnd_flip = ''
123
- for i, c in enumerate('xyz'):
124
- if torch.rand([1]) < self.flip[i]:
125
- rnd_flip = rnd_flip + c
126
-
127
- rnd_jitter = torch.rand(3) * (2 * self.jitter) - self.jitter
128
- return rnd_angle, rnd_scale, rnd_jitter, rnd_flip
129
-
130
-
131
- class CollateBatch:
132
- r''' Merge a list of octrees and points into a batch.
133
- '''
134
-
135
- def __init__(self, merge_points: bool = False):
136
- self.merge_points = merge_points
137
-
138
- def __call__(self, batch: list):
139
- assert type(batch) == list
140
-
141
- outputs = {}
142
- for key in batch[0].keys():
143
- outputs[key] = [b[key] for b in batch]
144
-
145
- # Merge a batch of octrees into one super octree
146
- if 'octree' in key:
147
- octree = ocnn.octree.merge_octrees(outputs[key])
148
- # NOTE: remember to construct the neighbor indices
149
- octree.construct_all_neigh()
150
- outputs[key] = octree
151
-
152
- # Merge a batch of points
153
- if 'points' in key and self.merge_points:
154
- outputs[key] = ocnn.octree.merge_points(outputs[key])
155
-
156
- # Convert the labels to a Tensor
157
- if 'label' in key:
158
- outputs['label'] = torch.tensor(outputs[key])
159
-
160
- return outputs
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+
10
+ import ocnn
11
+ from ocnn.octree import Octree, Points
12
+
13
+
14
+ __all__ = ['Transform', 'CollateBatch']
15
+ classes = __all__
16
+
17
+
18
+ class Transform:
19
+ r''' A boilerplate class which transforms an input data for :obj:`ocnn`.
20
+ The input data is first converted to :class:`Points`, then randomly transformed
21
+ (if enabled), and converted to an :class:`Octree`.
22
+
23
+ Args:
24
+ depth (int): The octree depth.
25
+ full_depth (int): The octree layers with a depth smaller than
26
+ :attr:`full_depth` are forced to be full.
27
+ distort (bool): If true, performs the data augmentation.
28
+ angle (list): A list of 3 float values to generate random rotation angles.
29
+ interval (list): A list of 3 float values to represent the interval of
30
+ rotation angles.
31
+ scale (float): The maximum relative scale factor.
32
+ uniform (bool): If true, performs uniform scaling.
33
+ jitter (float): The maximum jitter value.
34
+ orient_normal (str): Orient point normals along the specified axis, which is
35
+ useful when normals are not oriented.
36
+ '''
37
+
38
+ def __init__(self, depth: int, full_depth: int, distort: bool, angle: list,
39
+ interval: list, scale: float, uniform: bool, jitter: float,
40
+ flip: list, orient_normal: str = '', **kwargs):
41
+ super().__init__()
42
+
43
+ # for octree building
44
+ self.depth = depth
45
+ self.full_depth = full_depth
46
+
47
+ # for data augmentation
48
+ self.distort = distort
49
+ self.angle = angle
50
+ self.interval = interval
51
+ self.scale = scale
52
+ self.uniform = uniform
53
+ self.jitter = jitter
54
+ self.flip = flip
55
+
56
+ # for other transformations
57
+ self.orient_normal = orient_normal
58
+
59
+ def __call__(self, sample: dict, idx: int):
60
+ r''''''
61
+
62
+ output = self.preprocess(sample, idx)
63
+ output = self.transform(output, idx)
64
+ output['octree'] = self.points2octree(output['points'])
65
+ return output
66
+
67
+ def preprocess(self, sample: dict, idx: int):
68
+ r''' Transforms :attr:`sample` to :class:`Points` and performs some specific
69
+ transformations, like normalization.
70
+ '''
71
+
72
+ xyz = torch.from_numpy(sample.pop('points'))
73
+ normals = torch.from_numpy(sample.pop('normals'))
74
+ sample['points'] = Points(xyz, normals)
75
+ return sample
76
+
77
+ def transform(self, sample: dict, idx: int):
78
+ r''' Applies the general transformations provided by :obj:`ocnn`.
79
+ '''
80
+
81
+ # The augmentations including rotation, scaling, and jittering.
82
+ points = sample['points']
83
+ if self.distort:
84
+ rng_angle, rng_scale, rng_jitter, rnd_flip = self.rnd_parameters()
85
+ points.flip(rnd_flip)
86
+ points.rotate(rng_angle)
87
+ points.translate(rng_jitter)
88
+ points.scale(rng_scale)
89
+
90
+ if self.orient_normal:
91
+ points.orient_normal(self.orient_normal)
92
+
93
+ # !!! NOTE: Clip the point cloud to [-1, 1] before building the octree
94
+ inbox_mask = points.clip(min=-1, max=1)
95
+ sample.update({'points': points, 'inbox_mask': inbox_mask})
96
+ return sample
97
+
98
+ def points2octree(self, points: Points):
99
+ r''' Converts the input :attr:`points` to an octree.
100
+ '''
101
+
102
+ octree = Octree(self.depth, self.full_depth)
103
+ octree.build_octree(points)
104
+ return octree
105
+
106
+ def rnd_parameters(self):
107
+ r''' Generates random parameters for data augmentation.
108
+ '''
109
+
110
+ rnd_angle = [None] * 3
111
+ for i in range(3):
112
+ rot_num = self.angle[i] // self.interval[i]
113
+ rnd = torch.randint(low=-rot_num, high=rot_num+1, size=(1,))
114
+ rnd_angle[i] = rnd * self.interval[i] * (3.14159265 / 180.0)
115
+ rnd_angle = torch.cat(rnd_angle)
116
+
117
+ rnd_scale = torch.rand(3) * (2 * self.scale) - self.scale + 1.0
118
+ if self.uniform:
119
+ rnd_scale[1] = rnd_scale[0]
120
+ rnd_scale[2] = rnd_scale[0]
121
+
122
+ rnd_flip = ''
123
+ for i, c in enumerate('xyz'):
124
+ if torch.rand([1]) < self.flip[i]:
125
+ rnd_flip = rnd_flip + c
126
+
127
+ rnd_jitter = torch.rand(3) * (2 * self.jitter) - self.jitter
128
+ return rnd_angle, rnd_scale, rnd_jitter, rnd_flip
129
+
130
+
131
+ class CollateBatch:
132
+ r''' Merge a list of octrees and points into a batch.
133
+ '''
134
+
135
+ def __init__(self, merge_points: bool = False):
136
+ self.merge_points = merge_points
137
+
138
+ def __call__(self, batch: list):
139
+ assert type(batch) == list
140
+
141
+ outputs = {}
142
+ for key in batch[0].keys():
143
+ outputs[key] = [b[key] for b in batch]
144
+
145
+ # Merge a batch of octrees into one super octree
146
+ if 'octree' in key:
147
+ octree = ocnn.octree.merge_octrees(outputs[key])
148
+ # NOTE: remember to construct the neighbor indices
149
+ octree.construct_all_neigh()
150
+ outputs[key] = octree
151
+
152
+ # Merge a batch of points
153
+ if 'points' in key and self.merge_points:
154
+ outputs[key] = ocnn.octree.merge_points(outputs[key])
155
+
156
+ # Convert the labels to a Tensor
157
+ if 'label' in key:
158
+ outputs['label'] = torch.tensor(outputs[key])
159
+
160
+ return outputs
ocnn/models/__init__.py CHANGED
@@ -1,29 +1,29 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- from .lenet import LeNet
9
- from .resnet import ResNet
10
- from .segnet import SegNet
11
- from .unet import UNet
12
- from .hrnet import HRNet
13
- from .autoencoder import AutoEncoder
14
- from .ounet import OUNet
15
- from .image2shape import Image2Shape
16
-
17
-
18
- __all__ = [
19
- 'LeNet',
20
- 'ResNet',
21
- 'SegNet',
22
- 'UNet',
23
- 'HRNet',
24
- 'AutoEncoder',
25
- 'OUNet',
26
- 'Image2Shape',
27
- ]
28
-
29
- classes = __all__
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ from .lenet import LeNet
9
+ from .resnet import ResNet
10
+ from .segnet import SegNet
11
+ from .unet import UNet
12
+ from .hrnet import HRNet
13
+ from .autoencoder import AutoEncoder
14
+ from .ounet import OUNet
15
+ from .image2shape import Image2Shape
16
+
17
+
18
+ __all__ = [
19
+ 'LeNet',
20
+ 'ResNet',
21
+ 'SegNet',
22
+ 'UNet',
23
+ 'HRNet',
24
+ 'AutoEncoder',
25
+ 'OUNet',
26
+ 'Image2Shape',
27
+ ]
28
+
29
+ classes = __all__