pyect 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyect-0.1.0/LICENSE +23 -0
- pyect-0.1.0/PKG-INFO +80 -0
- pyect-0.1.0/README.md +52 -0
- pyect-0.1.0/pyect/__init__.py +9 -0
- pyect-0.1.0/pyect/directions.py +28 -0
- pyect-0.1.0/pyect/dtypes.py +7 -0
- pyect-0.1.0/pyect/image_ecf.py +220 -0
- pyect-0.1.0/pyect/preprocessing/__init__.py +0 -0
- pyect-0.1.0/pyect/preprocessing/image_processing.py +242 -0
- pyect-0.1.0/pyect/tensor_complex.py +204 -0
- pyect-0.1.0/pyect/wecfs.py +65 -0
- pyect-0.1.0/pyect/wect.py +130 -0
- pyect-0.1.0/pyect.egg-info/PKG-INFO +80 -0
- pyect-0.1.0/pyect.egg-info/SOURCES.txt +18 -0
- pyect-0.1.0/pyect.egg-info/dependency_links.txt +1 -0
- pyect-0.1.0/pyect.egg-info/requires.txt +2 -0
- pyect-0.1.0/pyect.egg-info/top_level.txt +1 -0
- pyect-0.1.0/pyproject.toml +3 -0
- pyect-0.1.0/setup.cfg +4 -0
- pyect-0.1.0/setup.py +24 -0
pyect-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Montana State University Computational Topology and
|
|
4
|
+
Geometry (CompTaG) Research Group.
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
8
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
9
|
+
in the Software without restriction, including without limitation the rights
|
|
10
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
11
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
12
|
+
furnished to do so, subject to the following conditions:
|
|
13
|
+
|
|
14
|
+
The above copyright notice and this permission notice shall be included in
|
|
15
|
+
all copies or substantial portions of the Software.
|
|
16
|
+
|
|
17
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
18
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
19
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
20
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
21
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
22
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
|
23
|
+
THE SOFTWARE.
|
pyect-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: pyect
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Generalized computation of the WECT using PyTorch.
|
|
5
|
+
Home-page: https://github.com/compTAG/pyECT
|
|
6
|
+
Author: Alex McCleary, Eli Quist, Jack Ruder, Jacob Sriraman
|
|
7
|
+
Author-email: eli.quist@student.montana.edu
|
|
8
|
+
License: MIT
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
11
|
+
Classifier: Operating System :: OS Independent
|
|
12
|
+
Requires-Python: >=3.8
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
License-File: LICENSE
|
|
15
|
+
Requires-Dist: torch
|
|
16
|
+
Requires-Dist: Pillow
|
|
17
|
+
Dynamic: author
|
|
18
|
+
Dynamic: author-email
|
|
19
|
+
Dynamic: classifier
|
|
20
|
+
Dynamic: description
|
|
21
|
+
Dynamic: description-content-type
|
|
22
|
+
Dynamic: home-page
|
|
23
|
+
Dynamic: license
|
|
24
|
+
Dynamic: license-file
|
|
25
|
+
Dynamic: requires-dist
|
|
26
|
+
Dynamic: requires-python
|
|
27
|
+
Dynamic: summary
|
|
28
|
+
|
|
29
|
+
# pyECT
|
|
30
|
+
|
|
31
|
+
The Weighted Euler Characteristic Transform (WECT) is a mathematical tool
|
|
32
|
+
used to analyze and summarize geometric and topological features of data.
|
|
33
|
+
This package provides an efficient and simple implementation of the WECT using
|
|
34
|
+
PyTorch.
|
|
35
|
+
|
|
36
|
+
This codebase accompanies the following paper (and should be cited if you use
|
|
37
|
+
this package):
|
|
38
|
+
|
|
39
|
+
```
|
|
40
|
+
TODO: Add Citation
|
|
41
|
+
```
|
|
42
|
+
|
|
43
|
+
## Installation
|
|
44
|
+
|
|
45
|
+
To install `pyECT`, use pip:
|
|
46
|
+
|
|
47
|
+
```bash
|
|
48
|
+
pip install pyect
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
## Usage
|
|
52
|
+
|
|
53
|
+
Here's a simple example of how to use `pyECT`:
|
|
54
|
+
|
|
55
|
+
```python
|
|
56
|
+
from pyect import WECT
|
|
57
|
+
|
|
58
|
+
# Example data and weight function
|
|
59
|
+
data = [...] # Replace with your data
|
|
60
|
+
weight_function = lambda x: x**2 # Replace with your weight function
|
|
61
|
+
|
|
62
|
+
# Compute the WECT
|
|
63
|
+
wect = WECT(data, weight_function)
|
|
64
|
+
result = wect.compute()
|
|
65
|
+
|
|
66
|
+
print("WECT result:", result)
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
For more detailed examples, please see the `/examples` directory.
|
|
70
|
+
|
|
71
|
+
## Contributing
|
|
72
|
+
|
|
73
|
+
Contributions are welcome! If you'd like to contribute, please fork the
|
|
74
|
+
repository and submit a pull request. For major changes, please open an issue
|
|
75
|
+
first to discuss what you'd like to change.
|
|
76
|
+
|
|
77
|
+
## License
|
|
78
|
+
|
|
79
|
+
This project is licensed under the MIT License. See the [LICENSE](LICENSE)
|
|
80
|
+
file for details.
|
pyect-0.1.0/README.md
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
# pyECT
|
|
2
|
+
|
|
3
|
+
The Weighted Euler Characteristic Transform (WECT) is a mathematical tool
|
|
4
|
+
used to analyze and summarize geometric and topological features of data.
|
|
5
|
+
This package provides an efficient and simple implementation of the WECT using
|
|
6
|
+
PyTorch.
|
|
7
|
+
|
|
8
|
+
This codebase accompanies the following paper (and should be cited if you use
|
|
9
|
+
this package):
|
|
10
|
+
|
|
11
|
+
```
|
|
12
|
+
TODO: Add Citation
|
|
13
|
+
```
|
|
14
|
+
|
|
15
|
+
## Installation
|
|
16
|
+
|
|
17
|
+
To install `pyECT`, use pip:
|
|
18
|
+
|
|
19
|
+
```bash
|
|
20
|
+
pip install pyect
|
|
21
|
+
```
|
|
22
|
+
|
|
23
|
+
## Usage
|
|
24
|
+
|
|
25
|
+
Here's a simple example of how to use `pyECT`:
|
|
26
|
+
|
|
27
|
+
```python
|
|
28
|
+
from pyect import WECT
|
|
29
|
+
|
|
30
|
+
# Example data and weight function
|
|
31
|
+
data = [...] # Replace with your data
|
|
32
|
+
weight_function = lambda x: x**2 # Replace with your weight function
|
|
33
|
+
|
|
34
|
+
# Compute the WECT
|
|
35
|
+
wect = WECT(data, weight_function)
|
|
36
|
+
result = wect.compute()
|
|
37
|
+
|
|
38
|
+
print("WECT result:", result)
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
For more detailed examples, please see the `/examples` directory.
|
|
42
|
+
|
|
43
|
+
## Contributing
|
|
44
|
+
|
|
45
|
+
Contributions are welcome! If you'd like to contribute, please fork the
|
|
46
|
+
repository and submit a pull request. For major changes, please open an issue
|
|
47
|
+
first to discuss what you'd like to change.
|
|
48
|
+
|
|
49
|
+
## License
|
|
50
|
+
|
|
51
|
+
This project is licensed under the MIT License. See the [LICENSE](LICENSE)
|
|
52
|
+
file for details.
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from .wect import WECT
|
|
2
|
+
from .tensor_complex import Complex
|
|
3
|
+
from .directions import sample_directions_2d, sample_directions_3d
|
|
4
|
+
from .image_ecf import Image_ECF_2D, Image_ECF_3D
|
|
5
|
+
from .preprocessing.image_processing import (
|
|
6
|
+
weighted_freudenthal,
|
|
7
|
+
weighted_cubical,
|
|
8
|
+
image_to_grayscale_tensor
|
|
9
|
+
)
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
# Golden angle in radians, pi * (3 - sqrt(5)) ~= 2.39996: the rotation step
# used by the Fibonacci-spiral sphere sampler in sample_directions_3d.
golden_angle = math.pi * (3.0 - math.sqrt(5.0))
|
|
5
|
+
|
|
6
|
+
def sample_directions_2d(num_dirs: int, *, device=None):
    """Sample ``num_dirs`` directions evenly spaced around the unit circle S^1.

    Args:
        num_dirs: Number of directions to generate.
        device: Optional torch device for the returned tensor.

    Returns:
        A contiguous float32 tensor of shape (num_dirs, 2) whose rows are
        unit vectors (cos(theta_k), sin(theta_k)) with theta_k = 2*pi*k/num_dirs.
    """
    indices = torch.arange(num_dirs, dtype=torch.float32, device=device)
    angles = 2 * math.pi * indices / num_dirs
    cosines = torch.cos(angles)
    sines = torch.sin(angles)
    result = torch.stack([cosines, sines], dim=-1)
    return result.contiguous()
|
|
14
|
+
|
|
15
|
+
def sample_directions_3d(num_dirs: int, *, device=None):
    """Sample ``num_dirs`` near-uniform directions on the sphere S^2 via the
    Fibonacci spiral method.

    Args:
        num_dirs: Number of directions to generate.
        device: Optional torch device for the returned tensor.

    Returns:
        A contiguous float32 tensor of shape (num_dirs, 3) of unit vectors.
    """
    indices = torch.arange(num_dirs, dtype=torch.float32, device=device)
    # Each successive point is rotated about the y-axis by the golden angle.
    theta = golden_angle * indices
    # Heights descend evenly from just below 1 to just above -1.
    y = 1.0 - (2.0 * (indices + 0.5) / num_dirs)
    # Radius of the horizontal circle at height y; clamp guards against
    # tiny negative values from floating-point round-off.
    r = torch.sqrt(torch.clamp(1.0 - y * y, min=0.0))
    x = torch.cos(theta) * r
    z = torch.sin(theta) * r
    return torch.stack([x, y, z], dim=-1).contiguous()
|
|
@@ -0,0 +1,220 @@
|
|
|
1
|
+
"""For computing the ECF of 2- and 3-dimensional images filtered by pixel intensity"""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from typing import List
|
|
5
|
+
|
|
6
|
+
class Image_ECF_2D(torch.nn.Module):
    """A torch module for computing the ECF of a 2D image filtered by pixel intensity.

    The module may be used standalone to compute ECFs of images, or as a layer
    inside a neural network.  The number of sample values is stored on the
    module itself, so repeated forward calls need no extra parameters, and
    saving/loading the module keeps the computation consistent.

    The module can also be converted to TorchScript with torch.jit.script for
    use outside of Python.
    """

    def __init__(self, num_vals: int) -> None:
        """Initializes the image_ECF module.

        The module computes the ECF of a 2D image, discretized over num_vals
        sample values.

        Args:
            num_vals: The number of values to discretize the ECF over.
        """
        super().__init__()
        self.num_vals: int = int(num_vals)

    @staticmethod
    def cell_values_2D(arr: torch.Tensor) -> List[torch.Tensor]:
        """Compute function values on the cells of the cubical complex of a 2D tensor.

        The structure of the complex is discarded; only the function value of
        each cell is kept.  Every edge and square takes the maximum of the
        pixel values it spans (the max extension of the vertex function).

        Args:
            arr (torch.Tensor): A 2D tensor with values between 0 and 1.

        Returns:
            A list [vertex_values, edge_values, square_values], each a 1D
            tensor of per-cell function values.
        """
        arr = arr.float()

        verts = arr.reshape(-1)

        # Edges join pixels adjacent along rows (x) and columns (y).
        x_edges = torch.maximum(arr[1:, :], arr[:-1, :])
        y_edges = torch.maximum(arr[:, 1:], arr[:, :-1])
        all_edges = torch.cat([x_edges.reshape(-1), y_edges.reshape(-1)], dim=0)

        # A square's value is the max over its four corner pixels, obtained by
        # maxing vertically adjacent horizontal edges.
        squares = torch.maximum(y_edges[1:, :], y_edges[:-1, :]).reshape(-1)

        return [verts, all_edges, squares]

    def forward(self, img_arr: torch.Tensor) -> torch.Tensor:
        """Calculate a discretization of the ECF of a 2D image.

        Args:
            img_arr (torch.Tensor): A 2D tensor with values between 0 and 1.

        Returns:
            ecf (torch.Tensor): A 1D tensor of shape (self.num_vals) containing the ECF.
        """
        device = img_arr.device
        n = self.num_vals
        verts, edges, squares = self.cell_values_2D(img_arr)

        # Map each cell value in [0, 1] to a bin index in [0, n-1].
        vert_bins = torch.ceil(verts * (n - 1)).long()
        edge_bins = torch.ceil(edges * (n - 1)).long()
        square_bins = torch.ceil(squares * (n - 1)).long()

        # Accumulate signed cell counts per bin: even-dimensional cells
        # contribute +1, odd-dimensional cells contribute -1.
        diff_ecf = torch.zeros(n, dtype=torch.int32, device=device)
        diff_ecf.scatter_add_(
            0, vert_bins, torch.ones_like(vert_bins, dtype=torch.int32)
        )
        diff_ecf.scatter_add_(
            0, edge_bins, -1 * torch.ones_like(edge_bins, dtype=torch.int32)
        )
        diff_ecf.scatter_add_(
            0, square_bins, torch.ones_like(square_bins, dtype=torch.int32)
        )

        # The running sum of the per-bin differences is the ECF.
        return torch.cumsum(diff_ecf, dim=0)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class Image_ECF_3D(torch.nn.Module):
    """A torch module for computing the ECF of a 3D image filtered by pixel intensity.

    The module may be used standalone to compute ECFs of images, or as a layer
    inside a neural network.  The number of sample values is stored on the
    module itself, so repeated forward calls need no extra parameters, and
    saving/loading the module keeps the computation consistent.

    The module can also be converted to TorchScript with torch.jit.script for
    use outside of Python.
    """

    def __init__(self, num_vals: int) -> None:
        """Initializes the image_ECF module.

        The module computes the ECF of a 3D image, discretized over num_vals
        sample values.

        Args:
            num_vals: The number of values to discretize the ECF over.
        """
        super().__init__()
        self.num_vals: int = int(num_vals)

    @staticmethod
    def cell_values_3D(arr: torch.Tensor) -> List[torch.Tensor]:
        """Compute function values on the cells of the cubical complex of a 3D tensor.

        The structure of the complex is discarded; only the function value of
        each cell is kept.  Every edge, square, and cube takes the maximum of
        the voxel values it spans (the max extension of the vertex function).

        Args:
            arr (torch.Tensor): A 3D tensor with values between 0 and 1.

        Returns:
            A list [vertex_values, edge_values, square_values, cube_values],
            each a 1D tensor of per-cell function values.
        """
        arr = arr.float()

        verts = arr.reshape(-1)

        # One family of edges per axis.
        x_edges = torch.maximum(arr[1:, ...], arr[:-1, ...])
        y_edges = torch.maximum(arr[:, 1:, :], arr[:, :-1, :])
        z_edges = torch.maximum(arr[..., 1:], arr[..., :-1])
        all_edges = torch.cat([
            x_edges.reshape(-1),
            y_edges.reshape(-1),
            z_edges.reshape(-1),
        ], dim=0)

        # One family of squares per axis pair, built by maxing parallel edges.
        x_squares = torch.maximum(y_edges[..., 1:], y_edges[..., :-1])
        y_squares = torch.maximum(z_edges[1:, ...], z_edges[:-1, ...])
        z_squares = torch.maximum(x_edges[:, 1:, :], x_edges[:, :-1, :])
        all_squares = torch.cat([
            x_squares.reshape(-1),
            y_squares.reshape(-1),
            z_squares.reshape(-1),
        ], dim=0)

        # A cube's value is the max over its eight corners, obtained by maxing
        # opposite faces.
        cubes = torch.maximum(x_squares[1:, ...], x_squares[:-1, ...]).reshape(-1)

        return [verts, all_edges, all_squares, cubes]

    def forward(self, img_arr: torch.Tensor) -> torch.Tensor:
        """Calculate a discretization of the ECF of a 3D image.

        Args:
            img_arr (torch.Tensor): A 3D tensor with values between 0 and 1.

        Returns:
            ecf (torch.Tensor): A 1D tensor of shape (self.num_vals) containing
                the sublevel set ECF.
        """
        device = img_arr.device
        n = self.num_vals
        verts, edges, squares, cubes = self.cell_values_3D(img_arr)

        # Map each cell value in [0, 1] to a bin index in [0, n-1].
        vert_bins = torch.ceil(verts * (n - 1)).long()
        edge_bins = torch.ceil(edges * (n - 1)).long()
        square_bins = torch.ceil(squares * (n - 1)).long()
        cube_bins = torch.ceil(cubes * (n - 1)).long()

        # Accumulate signed cell counts per bin: even-dimensional cells
        # contribute +1, odd-dimensional cells contribute -1.
        diff_ecf = torch.zeros(n, dtype=torch.int32, device=device)
        diff_ecf.scatter_add_(
            0, vert_bins, torch.ones_like(vert_bins, dtype=torch.int32)
        )
        diff_ecf.scatter_add_(
            0, edge_bins, -1 * torch.ones_like(edge_bins, dtype=torch.int32)
        )
        diff_ecf.scatter_add_(
            0, square_bins, torch.ones_like(square_bins, dtype=torch.int32)
        )
        diff_ecf.scatter_add_(
            0, cube_bins, -1 * torch.ones_like(cube_bins, dtype=torch.int32)
        )

        # The running sum of the per-bin differences is the ECF.
        return torch.cumsum(diff_ecf, dim=0)
|
|
File without changes
|
|
@@ -0,0 +1,242 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
import torch
|
|
3
|
+
import torchvision.transforms as transforms
|
|
4
|
+
from pyect import Complex
|
|
5
|
+
from PIL import Image
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def image_to_grayscale_tensor(image_path: str, device: torch.device) -> torch.Tensor:
    """Load an image file as a grayscale tensor with values in [0, 1].

    Args:
        image_path: Path of the image file to open.
        device: Device the resulting tensor is moved to.

    Returns:
        A tensor of shape (H, W) of pixel intensities in [0, 1].
    """
    # PIL mode "L" is single-channel 8-bit grayscale.
    grayscale = Image.open(image_path).convert("L")
    # ToTensor yields shape (1, H, W) scaled into [0, 1]; drop the channel dim.
    pixels = transforms.ToTensor()(grayscale).squeeze(dim=0)
    return pixels.to(device)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def weighted_freudenthal(
    img_arr: torch.Tensor, device: Optional[torch.device] = None
) -> Complex:
    """
    Creates the weighted Freudenthal complex of an image array using a max function extension.
    Discards edges and triangles that have a vertex with a zero weight.
    By default, the device of the input tensor is used unless a different device is specified.

    Only nonzero pixels become vertices, so:
    The vertices are a (num_nonzero_pixels, 2) tensor with recentered pixel coordinates.
    The vertex weights are a (num_nonzero_pixels,) tensor containing the pixel intensities.
    The edges are a (num_valid_edges, 2) tensor of vertex indices.
    The edge weights are a (num_valid_edges,) tensor with the maximum weight on the edge.
    The triangles are a (num_valid_triangles, 3) tensor of vertex indices.
    The triangle weights are a (num_valid_triangles,) tensor with the maximum weight on the triangle.

    Args:
        img_arr (torch.Tensor): A grayscale image of shape (h, w).
        device (torch.device, optional): The device to create tensors on.
            If None, the device of the input tensor is used.

    Returns:
        Complex: A complex containing the weighted vertices, weighted edges, and weighted triangles.
    """

    device = img_arr.device if device is None else device
    img_arr = img_arr.float().to(device)
    h, w = img_arr.shape

    # Create a mask of the nonzero pixels
    img_mask = img_arr != 0

    # Indices of nonzero pixels (vertices)
    nonzero_vertices = torch.nonzero(img_mask, as_tuple=True)

    # Enumerate the nonzero vertices in the index array with all other values set to 0.
    # vertex_numbers[i, j] is the row of pixel (i, j) in the vertex tensors below.
    vertex_numbers = torch.zeros_like(img_arr, dtype=torch.int64, device=device)
    vertex_numbers[nonzero_vertices] = torch.arange(
        nonzero_vertices[0].size(0), dtype=torch.int64, device=device
    )

    # Construct the vertex coords and weights.
    # Coordinates are recentered so the image midpoint maps to the origin:
    # x grows with the column index, y grows upward (row index flipped).
    vertex_coords = torch.stack([
        nonzero_vertices[1] - (w - 1) / 2.0,
        (h - 1) / 2.0 - nonzero_vertices[0]
    ], dim=1)
    vertex_weights = img_arr[nonzero_vertices]
    vertices = (vertex_coords, vertex_weights)

    ### Horizontal Edges
    # Remove the first and last columns of img_mask and check where the resulting arrays are both nonzero
    horizontal_edge_mask = img_mask[:, :-1] & img_mask[:, 1:]
    horizontal_edge_indices = torch.nonzero(horizontal_edge_mask, as_tuple=True)

    # Get the vertex numbers of the endpoints of each horizontal edge
    horizontal_edge_vertices = torch.stack([
        vertex_numbers[horizontal_edge_indices],
        vertex_numbers[:, 1:][horizontal_edge_indices]
    ], dim=1)
    # Max function extension: an edge's weight is the max of its endpoint weights.
    horizontal_edge_weights = vertex_weights[horizontal_edge_vertices].amax(dim=1)

    ### Vertical Edges
    # Remove the first and last rows of img_mask and check where the resulting arrays are both nonzero
    vertical_edge_mask = img_mask[:-1, :] & img_mask[1:, :]
    vertical_edge_indices = torch.nonzero(vertical_edge_mask, as_tuple=True)

    # Get the vertex numbers of the endpoints of each vertical edge
    vertical_edge_vertices = torch.stack([
        vertex_numbers[vertical_edge_indices],
        vertex_numbers[1:, :][vertical_edge_indices]
    ], dim=1)
    vertical_edge_weights = vertex_weights[vertical_edge_vertices].amax(dim=1)

    ### Diagonal Edges
    # The Freudenthal triangulation adds the down-right diagonal of each unit square.
    diagonal_edge_mask = img_mask[:-1, :-1] & img_mask[1:, 1:]
    diagonal_edge_indices = torch.nonzero(diagonal_edge_mask, as_tuple=True)
    diagonal_edge_vertices = torch.stack([
        vertex_numbers[diagonal_edge_indices],
        vertex_numbers[1:, 1:][diagonal_edge_indices]
    ], dim=1)
    diagonal_edge_weights = vertex_weights[diagonal_edge_vertices].amax(dim=1)

    # Concatenate the horizontal, vertical, and diagonal edges
    edge_vertices = torch.cat([
        horizontal_edge_vertices,
        vertical_edge_vertices,
        diagonal_edge_vertices
    ], dim=0)
    edge_weights = torch.cat([
        horizontal_edge_weights,
        vertical_edge_weights,
        diagonal_edge_weights
    ], dim=0)
    edges = (edge_vertices, edge_weights)

    ### Upper Triangles
    # Corners (i, j), (i, j+1), (i+1, j+1): the triangle above the diagonal.
    upper_triangle_mask = img_mask[:-1, :-1] & img_mask[:-1, 1:] & img_mask[1:, 1:]
    upper_triangle_indices = torch.nonzero(upper_triangle_mask, as_tuple=True)
    upper_triangle_vertices = torch.stack([
        vertex_numbers[upper_triangle_indices],
        vertex_numbers[:, 1:][upper_triangle_indices],
        vertex_numbers[1:, 1:][upper_triangle_indices]
    ], dim=1)
    upper_triangle_weights = vertex_weights[upper_triangle_vertices].amax(dim=1)

    ### Lower Triangles
    # Corners (i, j), (i+1, j), (i+1, j+1): the triangle below the diagonal.
    lower_triangle_mask = img_mask[:-1, :-1] & img_mask[1:, :-1] & img_mask[1:, 1:]
    lower_triangle_indices = torch.nonzero(lower_triangle_mask, as_tuple=True)
    lower_triangle_vertices = torch.stack([
        vertex_numbers[lower_triangle_indices],
        vertex_numbers[1:, :][lower_triangle_indices],
        vertex_numbers[1:, 1:][lower_triangle_indices]
    ], dim=1)
    lower_triangle_weights = vertex_weights[lower_triangle_vertices].amax(dim=1)

    ### Concatenate the upper and lower triangles
    triangle_vertices = torch.cat([
        upper_triangle_vertices,
        lower_triangle_vertices
    ], dim=0)
    triangle_weights = torch.cat([
        upper_triangle_weights,
        lower_triangle_weights
    ], dim=0)
    triangles = (triangle_vertices, triangle_weights)

    return Complex(vertices, edges, triangles, device=device)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def weighted_cubical(
    img_arr: torch.Tensor, device: Optional[torch.device] = None
) -> Complex:
    """
    Creates the weighted cubical complex of an image array.
    Discards edges and squares that have a vertex with zero weight.

    Only nonzero pixels become vertices, so:
    The vertices are a (num_nonzero_pixels, 2) tensor with recentered pixel coordinates.
    The vertex weights are a (num_nonzero_pixels,) tensor containing the pixel intensities.
    The edges are a (num_valid_edges, 2) tensor of vertex indices.
    The edge weights are a (num_valid_edges,) tensor with the maximum weight on the edge.
    The squares are a (num_valid_squares, 4) tensor of vertex indices.
    The square weights are a (num_valid_squares,) tensor with the maximum weight on
    the square.

    Args:
        img_arr (torch.Tensor): A grayscale image of shape (h, w).
        device (torch.device, optional): The device to create tensors on.
            If None, the device of the input tensor is used.

    Returns:
        Complex: A complex containing the weighted vertices, weighted edges, and weighted squares.
    """

    device = img_arr.device if device is None else device
    img_arr = img_arr.float().to(device)
    h, w = img_arr.shape

    # Create a mask of the nonzero pixels
    img_mask = img_arr != 0

    # Indices of nonzero pixels (vertices)
    nonzero_vertices = torch.nonzero(img_mask, as_tuple=True)

    # Create an array enumerating the nonzero vertices with all other values 0.
    # vertex_numbers[i, j] is the row of pixel (i, j) in the vertex tensors below.
    vertex_numbers = torch.zeros_like(img_arr, dtype=torch.int64, device=device)
    vertex_numbers[nonzero_vertices] = torch.arange(
        nonzero_vertices[0].size(0), dtype=torch.int64, device=device
    )

    # Construct the vertex coords and weights.
    # Coordinates are recentered so the image midpoint maps to the origin:
    # x grows with the column index, y grows upward (row index flipped).
    vertex_coords = torch.stack([
        nonzero_vertices[1] - (w - 1) / 2.0,
        (h - 1) / 2.0 - nonzero_vertices[0]
    ], dim=1)
    vertex_weights = img_arr[nonzero_vertices]
    vertices = (vertex_coords, vertex_weights)

    ### Horizontal Edges
    # Remove the first and last columns of img_mask and check where the resulting arrays are both nonzero
    horizontal_edge_mask = img_mask[:, :-1] & img_mask[:, 1:]
    horizontal_edge_indices = torch.nonzero(horizontal_edge_mask, as_tuple=True)

    # Get the vertex numbers of the endpoints of each horizontal edge
    horizontal_edge_vertices = torch.stack([
        vertex_numbers[horizontal_edge_indices],
        vertex_numbers[:, 1:][horizontal_edge_indices]
    ], dim=1)
    # Max function extension: an edge's weight is the max of its endpoint weights.
    horizontal_edge_weights = vertex_weights[horizontal_edge_vertices].amax(dim=1)

    ### Vertical Edges
    # Remove the first and last rows of img_mask and check where the resulting arrays are both nonzero
    vertical_edge_mask = img_mask[:-1, :] & img_mask[1:, :]
    vertical_edge_indices = torch.nonzero(vertical_edge_mask, as_tuple=True)

    # Get the vertex numbers of the endpoints of each vertical edge
    vertical_edge_vertices = torch.stack([
        vertex_numbers[vertical_edge_indices],
        vertex_numbers[1:, :][vertical_edge_indices]
    ], dim=1)
    vertical_edge_weights = vertex_weights[vertical_edge_vertices].amax(dim=1)

    # Concatenate the horizontal and vertical edges
    edge_vertices = torch.cat([
        horizontal_edge_vertices,
        vertical_edge_vertices
    ], dim=0)
    edge_weights = torch.cat([
        horizontal_edge_weights,
        vertical_edge_weights
    ], dim=0)
    edges = (edge_vertices, edge_weights)

    ###Squares
    # A square is kept only when both of its horizontal edges are valid,
    # i.e. all four corner pixels are nonzero.
    square_mask = horizontal_edge_mask[:-1, :] & horizontal_edge_mask[1:, :]
    square_indices = torch.nonzero(square_mask, as_tuple=True)
    # Corners in order: (i, j), (i+1, j), (i, j+1), (i+1, j+1).
    square_vertices = torch.stack([
        vertex_numbers[square_indices],
        vertex_numbers[1:, :][square_indices],
        vertex_numbers[:, 1:][square_indices],
        vertex_numbers[1:, 1:][square_indices]
    ], dim=1)
    square_weights = vertex_weights[square_vertices].amax(dim=1)
    squares = (square_vertices, square_weights)

    return Complex(vertices, edges, squares, n_type="cubical", device=device)
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
"""Tools for working with simplicial complexes.
|
|
2
|
+
|
|
3
|
+
The Complex class is a collection of simplices, each of which is represented by a
|
|
4
|
+
tensor of coordinates and a tensor of weights.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Tuple, Optional
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import warnings
|
|
11
|
+
import numpy.typing as npt
|
|
12
|
+
|
|
13
|
+
from .dtypes import COORDS_DTYPE, INDICES_DTYPE, WEIGHTS_DTYPE
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Complex:
|
|
17
|
+
"""A simplicial complex of arbitrary dimension.
|
|
18
|
+
|
|
19
|
+
The representation is as a collection of simplices (or cubical cells) using tensors.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
def __init__(
|
|
23
|
+
self,
|
|
24
|
+
*args: Tuple[torch.Tensor, torch.Tensor],
|
|
25
|
+
vertex_dtype: torch.dtype = COORDS_DTYPE,
|
|
26
|
+
index_dtype: torch.dtype = INDICES_DTYPE,
|
|
27
|
+
weights_dtype: torch.dtype = WEIGHTS_DTYPE,
|
|
28
|
+
device: Optional[torch.device] = None,
|
|
29
|
+
n_type: str = "simplicial",
|
|
30
|
+
) -> None:
|
|
31
|
+
"""Initializes a complex.
|
|
32
|
+
|
|
33
|
+
All tensors are cast to the given types.
|
|
34
|
+
|
|
35
|
+
Args:
|
|
36
|
+
*args: A variable number of tuples, each containing the simplices of a given
|
|
37
|
+
dimension. Each tuple should contain two tensors.
|
|
38
|
+
The first tensor contains the coordinates of the simplices
|
|
39
|
+
The second tensor contains the weights of the simplices.
|
|
40
|
+
|
|
41
|
+
The first tuple should contain the vertices of the complex, and
|
|
42
|
+
therefore must be a tensor of shape [num_vertices, d].
|
|
43
|
+
|
|
44
|
+
Any following tuples should contain indices into the vertices tensor,
|
|
45
|
+
and therefore must be a tensor of shape [num_simplices, k], where k+1 is the
|
|
46
|
+
dimension of the simplex.
|
|
47
|
+
|
|
48
|
+
vertex_dtype: The data type to use for the vertex coordinates.
|
|
49
|
+
index_dtype: The data type to use for the simplex indices.
|
|
50
|
+
weights_dtype: The data type to use for the simplex weights.
|
|
51
|
+
device: The device to use for the tensors.
|
|
52
|
+
n_type: The type of complex. Currently only "simplicial" and "cubical"
|
|
53
|
+
are supported.
|
|
54
|
+
"""
|
|
55
|
+
# Verify the dimensions of the simplices, and raise a UserError if
|
|
56
|
+
# there is a mismatch.
|
|
57
|
+
self._validate_dimensions(*args, n_type=n_type)
|
|
58
|
+
|
|
59
|
+
# Call .to on each tensor to cast to the given type and device.
|
|
60
|
+
types = [vertex_dtype] + [index_dtype] * (len(args) - 1)
|
|
61
|
+
self.dimensions = tuple(
|
|
62
|
+
(
|
|
63
|
+
(
|
|
64
|
+
coords.to(dtype=types[dim], device=device),
|
|
65
|
+
weights.to(dtype=weights_dtype, device=device),
|
|
66
|
+
)
|
|
67
|
+
)
|
|
68
|
+
for dim, (coords, weights) in enumerate(args)
|
|
69
|
+
)
|
|
70
|
+
self.n_type = n_type
|
|
71
|
+
|
|
72
|
+
@staticmethod
def from_numpy(
    *args: Tuple[npt.NDArray, npt.NDArray],
    vertex_dtype: torch.dtype = COORDS_DTYPE,
    index_dtype: torch.dtype = INDICES_DTYPE,
    weights_dtype: torch.dtype = WEIGHTS_DTYPE,
    device: Optional[torch.device] = None,
    n_type: str = "simplicial",
) -> "Complex":
    """Builds a Complex from numpy arrays.

    Args:
        *args: A variable number of tuples, one per simplex dimension.
            Each tuple holds two numpy arrays: the first carries the
            coordinates of the simplices, the second their weights.

            The first tuple must hold the vertices of the complex, i.e.
            an array of shape [num_vertices, d].

            Every later tuple holds indices into the vertex array, i.e.
            an array of shape [num_simplices, k], where k+1 is the
            dimension of the simplex.

        vertex_dtype: The data type for the vertex coordinates.
        index_dtype: The data type for the simplex indices.
        weights_dtype: The data type for the simplex weights.
        device: The device to place the tensors on. When None, CUDA is
            used if available, otherwise the CPU.
        n_type: The type of the complex. Currently only "simplicial" and
            "cubical" are supported.

    Returns:
        A Complex holding the converted tensors.
    """
    if device is None:
        # Prefer the GPU when one is present.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Vertices keep coordinate dtype; every higher dimension holds indices.
    dtypes = [vertex_dtype] + [index_dtype] * (len(args) - 1)
    tensor_pairs = tuple(
        (
            torch.as_tensor(coords, device=device, dtype=dtypes[dim]),
            torch.as_tensor(weights, device=device, dtype=weights_dtype),
        )
        for dim, (coords, weights) in enumerate(args)
    )
    return Complex(*tensor_pairs, device=device, n_type=n_type)
|
|
121
|
+
|
|
122
|
+
def to(self, device: torch.device) -> "Complex":
    """Returns a copy of this complex with all tensors placed on *device*."""
    moved = Complex(*self.dimensions, device=device, n_type=self.n_type)
    return moved
|
|
125
|
+
|
|
126
|
+
def __getitem__(self, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Returns the (coordinates, weights) pair for dimension *dim*."""
    coords_and_weights = self.dimensions[dim]
    return coords_and_weights
|
|
129
|
+
|
|
130
|
+
def get_coords(self, dim: int) -> torch.Tensor:
    """Returns the coordinate tensor of the *dim*-dimensional simplices."""
    coords, _ = self.dimensions[dim]
    return coords
|
|
133
|
+
|
|
134
|
+
def get_weights(self, dim: int) -> torch.Tensor:
    """Returns the weight tensor of the *dim*-dimensional simplices."""
    _, weights = self.dimensions[dim]
    return weights
|
|
137
|
+
|
|
138
|
+
def top_dim(self) -> int:
    """Returns the highest simplex dimension stored in the complex."""
    num_dimensions = len(self)
    return num_dimensions - 1
|
|
141
|
+
|
|
142
|
+
def __len__(self) -> int:
    """Returns how many simplex dimensions the complex stores."""
    return len(self.dimensions)
|
|
145
|
+
|
|
146
|
+
def space_dim(self) -> int:
    """Returns the ambient dimension n of the embedding space R^n."""
    vertex_coords = self.dimensions[0][0]
    return vertex_coords.shape[1]
|
|
149
|
+
|
|
150
|
+
def center_(self) -> "Complex":
    """Translates the complex in-place so its vertex centroid is the origin.

    Subtracts the mean vertex coordinate from every vertex; higher
    dimensions (which hold indices, not coordinates) are left untouched.
    No-op when the complex has no dimensions or no vertices.

    Returns:
        self, to allow call chaining.
    """
    if not self.dimensions:
        return self

    vertex_coords, vertex_weights = self.dimensions[0]
    if vertex_coords.numel() == 0:
        return self

    centroid = vertex_coords.mean(dim=0)
    shifted = (vertex_coords - centroid).contiguous()

    # Rebuild the immutable tuple with only the vertex entry replaced.
    self.dimensions = ((shifted, vertex_weights),) + tuple(self.dimensions[1:])
    return self
|
|
168
|
+
|
|
169
|
+
@staticmethod
def _validate_dimensions(
    *args: Tuple[torch.Tensor, torch.Tensor], n_type: str
) -> None:
    """Validates the shapes of the per-dimension (coords, weights) pairs.

    Raises:
        ValueError: If a coordinate tensor is not 2d, a weight tensor is
            not 1d, the simplex/weight counts disagree, or (for the known
            n_types) a simplex tensor has the wrong number of columns.

    For an unknown n_type the column check is skipped and a warning is
    emitted instead of an error.
    """
    for dim, (coords, weights) in enumerate(args):
        if coords.dim() != 2:
            raise ValueError(
                f"Dimension {dim} simplices must be a 2d tensor."
                + f" Got {coords.dim()} dimensions."
            )
        if weights.dim() != 1:
            raise ValueError(
                f"Dimension {dim} weights must be a 1d tensor."
                + f" Got {weights.dim()} dimensions."
            )
        if coords.shape[0] != weights.shape[0]:
            raise ValueError(
                f"Dimension {dim} coordinates and weights must have the same number of simplices."
                + f" Got {coords.shape[0]} simplices and {weights.shape[0]} weights."
            )

        if dim == 0:
            # Vertices carry coordinates, not index columns: nothing more to check.
            continue

        if n_type == "simplicial":
            # A k-simplex is spanned by k+1 vertices.
            expected_cols = dim + 1
        elif n_type == "cubical":
            # A k-cube is spanned by 2^k vertices.
            expected_cols = 2 ** dim
        else:
            # Validation is not implemented for this n_type; warn but allow it.
            warnings.warn(f"Validation not implemented for n_type {n_type}. Proceed with caution.")
            continue

        if coords.shape[1] != expected_cols:
            raise ValueError(
                f"Dimension {dim} simplices must have {expected_cols} columns."
                + f" Got {coords.shape[1]} columns."
            )
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
""" For computing the WECFs of lower-star filtrations of
|
|
2
|
+
weighted simplicial/cubical complex with respect to a set of filter functions."""
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from typing import List, Tuple
|
|
6
|
+
|
|
7
|
+
def compute_wecfs(
    complex_data: List[Tuple[torch.Tensor, torch.Tensor]],
    num_vals: int
) -> torch.Tensor:
    """Calculates a discretization of the WECFs of a weighted complex with respect to a set of filter functions.

    Args:
        complex_data: A weighted simplicial or cubical complex with a collection of filter functions,
            represented as a list of pairs of tensors.
            complex_data[0] = (filters, v_weights):
                filters (torch.Tensor): A tensor of shape (k_0, m) where k_0 is the
                    number of vertices and m is the number of filter functions.
                    Each column contains the values of a filter function on the vertices.

                v_weights (torch.Tensor): A tensor of shape (k_0). Values are the weights of the vertices.

            for i > 0:
                complex_data[i] = (simp_verts, simp_weights):
                    simp_verts (torch.Tensor): A tensor of shape (k_i, i+1) where k_i is the number of i-simplices.
                        Rows are the vertex sets of the i-simplices.

                    simp_weights (torch.Tensor): A tensor of shape (k_i). Values are the weights of the i-simplices.
        num_vals: The number of filtration values to discretize over.

    Returns:
        wecfs (torch.Tensor): A 2d tensor of shape (m, num_vals)
            containing the WECFs.

    Raises:
        ValueError: If num_vals is not positive.
    """
    if num_vals <= 0:
        raise ValueError("num_vals must be positive.")

    filters = complex_data[0][0].float()
    m = filters.size(dim=1)
    device = filters.device
    v_weights = complex_data[0][1].to(device=device, dtype=torch.float32)

    expanded_v_weights = v_weights.unsqueeze(0).expand(m, -1)  # Expand to shape (m, k_0)

    # Map the values of the filter functions to indices in range(num_vals).
    max_val = filters.abs().amax()
    if max_val.item() == 0.0:
        # All filter values are zero: dividing by max_val would produce NaN
        # indices. Send every vertex to index 0 instead, matching the
        # degenerate-case guard in WECT._vertex_indices.
        v_indices = torch.zeros_like(filters, dtype=torch.long)
    else:
        v_indices = torch.ceil(
            (num_vals - 1) * (max_val + filters) / (2.0 * max_val)
        ).clamp(0, num_vals - 1).long()

    # Differentiated WECFs: signed weight increments at each filtration index.
    diff_wecfs = torch.zeros((m, num_vals), dtype=torch.float32, device=device)

    # Add the contribution of the vertices to the differentiated WECFs.
    diff_wecfs.scatter_add_(1, v_indices.T, expanded_v_weights)

    for i in range(1, len(complex_data)):
        simp_verts = complex_data[i][0].to(device=device, dtype=torch.long)
        simp_weights = complex_data[i][1].to(device=device, dtype=torch.float32)

        # Sign alternates with dimension, as in the Euler characteristic.
        expanded_simp_weights = (-1) ** i * simp_weights.unsqueeze(0).expand(m, -1)

        # A simplex enters the lower-star filtration with its last vertex.
        simp_indices = v_indices[simp_verts]
        max_simp_indices = torch.amax(simp_indices, dim=1)

        diff_wecfs.scatter_add_(1, max_simp_indices.T, expanded_simp_weights)

    # Integrate the increments to obtain the WECFs.
    return torch.cumsum(diff_wecfs, dim=1)
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
"""For computing the WECT of a weighted geometric simplicial/cubical complex embedded in R^n."""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from typing import List, Tuple
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class WECT(torch.nn.Module):
    """A torch module for computing the Weighted Euler Characteristic Transform (WECT) of a simplicial complex discretized over a grid.

    The module may be used on its own to compute WECTs, or as a layer inside a
    neural network. The sampling directions and the number of height values
    live on the module itself, so repeated forward calls need no extra
    parameters, and saving/loading the module keeps the computation
    consistent.

    The module can also be converted to TorchScript with torch.jit.script for
    use outside of Python.
    """

    def __init__(self, dirs: torch.Tensor, num_heights: int) -> None:
        """Initializes the WECT module.

        The module computes WECTs of complexes embedded in R^[dirs.shape[1]],
        sampled along dirs.shape[0] directions, with the discretization
        parameterized by num_heights distinct height values.

        Args:
            dirs: A (d x n) tensor of sampling directions; each row is
                normalized to unit length before being stored.
            num_heights: The number of distinct height values to round to,
                as an integer.
        """
        super().__init__()
        unit_dirs = torch.nn.functional.normalize(dirs, p=2, dim=1, eps=1e-12)
        # A buffer (not a parameter): moves with the module but is not trained.
        self.register_buffer("dirs", unit_dirs)
        self.num_heights: int = int(num_heights)

    def _vertex_indices(
        self,
        vertex_coords: torch.Tensor,
    ) -> torch.Tensor:
        """Maps each vertex height to an index in range(num_heights) per direction.

        Args:
            vertex_coords (torch.Tensor): A (k_0, n) tensor whose rows are the
                vertex coordinates.

        Returns:
            torch.Tensor: A (k_0, d) long tensor with the height index of each
                vertex in each direction.
        """
        max_height = torch.amax(torch.norm(vertex_coords, dim=1))
        heights = torch.matmul(vertex_coords, self.dirs.T)

        # Degenerate case: every vertex sits at the origin.
        if max_height.item() == 0.0:
            return torch.zeros(
                (heights.size(0), self.dirs.size(0)),
                dtype=torch.long,
                device=self.dirs.device,
            )

        # Affinely map heights from [-max_height, max_height] onto the index range.
        scaled = (self.num_heights - 1) * (max_height + heights) / (2.0 * max_height)
        return torch.ceil(scaled).clamp(0, self.num_heights - 1).long()

    def forward(
        self,
        complex_data: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> torch.Tensor:
        """Calculates a discretization of the WECT of a complex embedded in n-dimensional space.

        Args:
            complex_data: A weighted simplicial or cubical complex, represented as a list of pairs of tensors.
                complex_data[0] = (v_coords, v_weights):
                    v_coords (torch.Tensor): A tensor of shape (k_0, n) where k_0 is the number of vertices.
                        Rows are the coordinates of the vertices.

                    v_weights (torch.Tensor): A tensor of shape (k_0). Values are the weights of the vertices.

                for i > 0:
                    complex_data[i] = (simp_verts, simp_weights):
                        simp_verts (torch.Tensor): A tensor of shape (k_i, i+1) where k_i is the number of i-simplices.
                            Rows are the vertex sets of the i-simplices.

                        simp_weights (torch.Tensor): A tensor of shape (k_i). Values are the weights of the i-simplices.

        Returns:
            wect (torch.Tensor): A 2d tensor of shape (self.dirs.shape[0], self.num_heights)
                containing the WECT.

        Raises:
            ValueError: If num_heights is not positive.
        """
        num_dirs = self.dirs.size(dim=0)
        num_heights = self.num_heights
        if num_heights <= 0:
            raise ValueError("num_heights must be positive.")

        device = self.dirs.device
        v_coords = complex_data[0][0].to(device=device, dtype=torch.float32)
        v_weights = complex_data[0][1].to(device=device, dtype=torch.float32)

        # A complex without vertices has an identically-zero WECT.
        if v_coords.size(0) == 0:
            return torch.zeros((num_dirs, num_heights), dtype=torch.float32, device=device)

        # Differentiated WECT: signed weight increments at each height index.
        diff_wect = torch.zeros((num_dirs, num_heights), dtype=torch.float32, device=device)

        # Vertices contribute their weight at their own height index.
        v_indices = self._vertex_indices(v_coords)
        diff_wect.scatter_add_(1, v_indices.T, v_weights.unsqueeze(0).expand(num_dirs, -1))

        for dim in range(1, len(complex_data)):
            simp_verts = complex_data[dim][0].to(device=device, dtype=torch.long)
            simp_weights = complex_data[dim][1].to(device=device, dtype=torch.float32)

            # Euler-characteristic sign alternates with simplex dimension.
            signed_weights = (-1) ** dim * simp_weights.unsqueeze(0).expand(num_dirs, -1)

            # A simplex appears once its highest vertex does (lower-star rule).
            top_indices = torch.amax(v_indices[simp_verts], dim=1)
            diff_wect.scatter_add_(1, top_indices.T, signed_weights)

        # Integrate the increments along the height axis.
        return torch.cumsum(diff_wect, dim=1)
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: pyect
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Generalized computation of the WECT using PyTorch.
|
|
5
|
+
Home-page: https://github.com/compTAG/pyECT
|
|
6
|
+
Author: Alex McCleary, Eli Quist, Jack Ruder, Jacob Sriraman
|
|
7
|
+
Author-email: eli.quist@student.montana.edu
|
|
8
|
+
License: MIT
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
11
|
+
Classifier: Operating System :: OS Independent
|
|
12
|
+
Requires-Python: >=3.8
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
License-File: LICENSE
|
|
15
|
+
Requires-Dist: torch
|
|
16
|
+
Requires-Dist: Pillow
|
|
17
|
+
Dynamic: author
|
|
18
|
+
Dynamic: author-email
|
|
19
|
+
Dynamic: classifier
|
|
20
|
+
Dynamic: description
|
|
21
|
+
Dynamic: description-content-type
|
|
22
|
+
Dynamic: home-page
|
|
23
|
+
Dynamic: license
|
|
24
|
+
Dynamic: license-file
|
|
25
|
+
Dynamic: requires-dist
|
|
26
|
+
Dynamic: requires-python
|
|
27
|
+
Dynamic: summary
|
|
28
|
+
|
|
29
|
+
# pyECT
|
|
30
|
+
|
|
31
|
+
The Weighted Euler Characteristic Transform (WECT) is a mathematical tool
|
|
32
|
+
used to analyze and summarize geometric and topological features of data.
|
|
33
|
+
This package provides an efficient and simple implementation of the WECT using
|
|
34
|
+
PyTorch.
|
|
35
|
+
|
|
36
|
+
This codebase accompanies the following paper (and should be cited if you use
|
|
37
|
+
this package):
|
|
38
|
+
|
|
39
|
+
```
|
|
40
|
+
TODO: Add Citation
|
|
41
|
+
```
|
|
42
|
+
|
|
43
|
+
## Installation
|
|
44
|
+
|
|
45
|
+
To install `pyECT`, use pip:
|
|
46
|
+
|
|
47
|
+
```bash
|
|
48
|
+
pip install pyect
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
## Usage
|
|
52
|
+
|
|
53
|
+
Here's a simple example of how to use `pyECT`:
|
|
54
|
+
|
|
55
|
+
```python
import torch
from pyect import WECT

# Sampling directions (rows) and the number of height values to discretize over
dirs = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
wect = WECT(dirs, num_heights=16)

# A weighted complex: a list of (tensor, tensor) pairs — see WECT.forward
complex_data = [...]  # Replace with your data

# Compute the WECT
result = wect(complex_data)

print("WECT result:", result)
```
|
|
68
|
+
|
|
69
|
+
For more detailed examples, please see the `/examples` directory.
|
|
70
|
+
|
|
71
|
+
## Contributing
|
|
72
|
+
|
|
73
|
+
Contributions are welcome! If you'd like to contribute, please fork the
|
|
74
|
+
repository and submit a pull request. For major changes, please open an issue
|
|
75
|
+
first to discuss what you'd like to change.
|
|
76
|
+
|
|
77
|
+
## License
|
|
78
|
+
|
|
79
|
+
This project is licensed under the MIT License. See the [LICENSE](LICENSE)
|
|
80
|
+
file for details.
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
setup.py
|
|
5
|
+
pyect/__init__.py
|
|
6
|
+
pyect/directions.py
|
|
7
|
+
pyect/dtypes.py
|
|
8
|
+
pyect/image_ecf.py
|
|
9
|
+
pyect/tensor_complex.py
|
|
10
|
+
pyect/wecfs.py
|
|
11
|
+
pyect/wect.py
|
|
12
|
+
pyect.egg-info/PKG-INFO
|
|
13
|
+
pyect.egg-info/SOURCES.txt
|
|
14
|
+
pyect.egg-info/dependency_links.txt
|
|
15
|
+
pyect.egg-info/requires.txt
|
|
16
|
+
pyect.egg-info/top_level.txt
|
|
17
|
+
pyect/preprocessing/__init__.py
|
|
18
|
+
pyect/preprocessing/image_processing.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
pyect
|
pyect-0.1.0/setup.cfg
ADDED
pyect-0.1.0/setup.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from setuptools import setup, find_packages

# Read the long description with an explicit encoding and close the file
# promptly. The previous bare open('README.md').read() leaked the file
# handle and relied on the platform default encoding, which breaks the
# build on non-UTF-8 locales.
with open('README.md', encoding='utf-8') as _readme:
    _long_description = _readme.read()

setup(
    name='pyect',
    version='0.1.0',
    author='Alex McCleary, Eli Quist, Jack Ruder, Jacob Sriraman',
    author_email='eli.quist@student.montana.edu',
    description='Generalized computation of the WECT using PyTorch.',
    long_description=_long_description,
    long_description_content_type='text/markdown',
    url='https://github.com/compTAG/pyECT',
    license='MIT',
    packages=find_packages(),
    install_requires=[
        'torch',
        'Pillow',
    ],
    classifiers=[
        'Programming Language :: Python :: 3',
        'License :: OSI Approved :: MIT License',
        'Operating System :: OS Independent',
    ],
    python_requires='>=3.8',
)
|