pyect 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pyect/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .wect import WECT
2
+ from .tensor_complex import Complex
3
+ from .directions import sample_directions_2d, sample_directions_3d
4
+ from .image_ecf import Image_ECF_2D, Image_ECF_3D
5
+ from .preprocessing.image_processing import (
6
+ weighted_freudenthal,
7
+ weighted_cubical,
8
+ image_to_grayscale_tensor
9
+ )
pyect/directions.py ADDED
@@ -0,0 +1,28 @@
1
+ import math
2
+ import torch
3
+
4
# The golden angle in radians, pi * (3 - sqrt(5)) ~= 2.39996. The Fibonacci-spiral
# sampler rotates each successive point by this angle.
golden_angle: float = math.pi * (3.0 - math.sqrt(5.0))
5
+
6
def sample_directions_2d(num_dirs: int, *, device=None):
    """
    Sample num_dirs directions evenly from S^1.

    The k-th direction is (cos(2*pi*k/num_dirs), sin(2*pi*k/num_dirs)). Directions start
    at (1, 0) and proceed counter-clockwise.

    Args:
        num_dirs: Number of directions to generate.
        device: Optional torch device on which to allocate the result.

    Returns:
        A contiguous float32 tensor of shape (num_dirs, 2) of unit vectors.
    """
    step = torch.arange(num_dirs, dtype=torch.float32, device=device)
    theta = 2 * math.pi * step / num_dirs
    unit_vectors = torch.stack([theta.cos(), theta.sin()], dim=-1)
    return unit_vectors.contiguous()
14
+
15
def sample_directions_3d(num_dirs: int, *, device=None):
    """
    Sample num_dirs directions from S^2 using the Fibonacci spiral method.

    The y-coordinates are evenly spaced in (-1, 1), offset by half a step so that the
    poles are never hit exactly. Each successive point is rotated about the y-axis by
    the module-level ``golden_angle``.

    Args:
        num_dirs: Number of directions to generate.
        device: Optional torch device on which to allocate the result.

    Returns:
        A contiguous float32 tensor of shape (num_dirs, 3) with columns (x, y, z).
    """
    k = torch.arange(num_dirs, dtype=torch.float32, device=device)
    y = 1.0 - (2.0 * (k + 0.5) / num_dirs)
    # Radius of the horizontal circle at height y. The clamp absorbs tiny negative
    # values caused by rounding.
    radius = (1.0 - y * y).clamp(min=0.0).sqrt()
    azimuth = golden_angle * k
    points = torch.stack([azimuth.cos() * radius, y, azimuth.sin() * radius], dim=-1)
    return points.contiguous()
pyect/dtypes.py ADDED
@@ -0,0 +1,7 @@
1
+ """Export types for the PyECT package."""
2
+
3
+ import torch
4
+
5
# Default dtypes used by ``Complex`` when it casts the tensors it is given.
COORDS_DTYPE: torch.dtype = torch.float32  # vertex coordinates
INDICES_DTYPE: torch.dtype = torch.int64  # vertex indices of higher-dimensional cells
WEIGHTS_DTYPE: torch.dtype = torch.float32  # per-cell weights
pyect/image_ecf.py ADDED
@@ -0,0 +1,220 @@
1
+ """For computing the ECF of 2- and 3-dimensional images filtered by pixel intensity"""
2
+
3
+ import torch
4
+ from typing import List
5
+
6
class Image_ECF_2D(torch.nn.Module):
    """A torch module for computing the ECF of a 2D image filtered by pixel intensity.

    The image is treated as a cubical complex:

    * every pixel is a vertex;
    * every pair of pixels adjacent along a row or column spans an edge;
    * every 2x2 block of pixels spans a square.

    Each cell takes the maximum intensity of its pixels. The Euler characteristic
    (#vertices - #edges + #squares) of the sublevel sets is sampled at ``num_vals``
    evenly spaced thresholds. For ``num_vals >= 2``, entry ``k`` of the output is the
    Euler characteristic of the cells whose value is at most ``k / (num_vals - 1)``.

    This module may be used just for computing the ECF of images, or used as a layer in
    a neural network. Internally, the module stores the number of values used for
    sampling. Repeated forward calls therefore do not require this parameter to be
    passed in, and the module can be loaded/saved for consistent computation.

    This module can also be converted to TorchScript using torch.jit.script for use
    outside of Python.
    """

    def __init__(self, num_vals: int) -> None:
        """Initializes the Image_ECF_2D module.

        Args:
            num_vals: The number of threshold values to discretize the ECF over.
                Must be at least 1.

        Raises:
            ValueError: If num_vals is less than 1.
        """
        super().__init__()
        num_vals = int(num_vals)
        if num_vals < 1:
            raise ValueError(f"num_vals must be at least 1, got {num_vals}.")
        self.num_vals: int = num_vals

    @staticmethod
    def cell_values_2D(arr: torch.Tensor) -> List[torch.Tensor]:
        """
        Creates a cubical complex with a function on its cells from a 2D tensor.
        The structure of the cubical complex is ignored with only the function values on
        the cells being recorded.

        Args:
            arr (torch.Tensor): A 2D tensor, nominally with values between 0 and 1.

        Returns:
            A list [vertex_values, edge_values, square_values] of 1D tensors holding
            the function value (max over the cell's pixels) of every vertex, edge and
            square, respectively.
        """
        arr = arr.float()

        vertex_values = arr.reshape(-1)

        # Edges between pixels adjacent along dim 0, then along dim 1.
        x_edge_values = torch.maximum(arr[1:, :], arr[:-1, :])
        y_edge_values = torch.maximum(arr[:, 1:], arr[:, :-1])
        edge_values = torch.cat([
            x_edge_values.reshape(-1),
            y_edge_values.reshape(-1)
        ], dim=0)

        # A square is the max of two dim-1 edges stacked along dim 0 (its four corners).
        square_values = torch.maximum(y_edge_values[1:, :], y_edge_values[:-1, :]).reshape(-1)

        return [vertex_values, edge_values, square_values]

    def forward(self, img_arr: torch.Tensor) -> torch.Tensor:
        """
        Calculates a discretization of the ECF of a 2D image.

        Indices are clamped to [0, num_vals - 1]. Intensities below 0 or above 1 are
        therefore treated as if the image were clipped to [0, 1], instead of causing
        an out-of-range scatter.

        Args:
            img_arr (torch.Tensor): A 2D tensor, nominally with values between 0 and 1.

        Returns:
            ecf (torch.Tensor): A 1D integer tensor of shape (self.num_vals,)
            containing the ECF.
        """
        n = self.num_vals
        cell_values = self.cell_values_2D(img_arr)

        diff_ecf = torch.zeros(n, dtype=torch.int32, device=img_arr.device)

        # Cells of dimension d contribute (-1)**d at the first threshold index where
        # they appear. Clamp before the integer cast so huge values cannot overflow.
        for dim, values in enumerate(cell_values):
            indices = torch.ceil(values * (n - 1)).clamp(0, n - 1).long()
            sign = 1 if dim % 2 == 0 else -1
            diff_ecf.scatter_add_(
                0,
                indices,
                torch.full_like(indices, sign, dtype=torch.int32)
            )

        return torch.cumsum(diff_ecf, dim=0)
102
+
103
+
104
class Image_ECF_3D(torch.nn.Module):
    """A torch module for computing the ECF of a 3D image filtered by voxel intensity.

    The image is treated as a cubical complex:

    * every voxel is a vertex;
    * voxels adjacent along one axis span an edge;
    * 2x2 blocks in an axis-aligned plane span a square;
    * 2x2x2 blocks span a cube.

    Each cell takes the maximum intensity of its voxels. The Euler characteristic
    (#vertices - #edges + #squares - #cubes) of the sublevel sets is sampled at
    ``num_vals`` evenly spaced thresholds. For ``num_vals >= 2``, entry ``k`` is the
    Euler characteristic of the cells whose value is at most ``k / (num_vals - 1)``.

    This module may be used just for computing the ECF of images, or used as a layer in
    a neural network. Internally, the module stores the number of values used for
    sampling. Repeated forward calls therefore do not require this parameter to be
    passed in, and the module can be loaded/saved for consistent computation.

    This module can also be converted to TorchScript using torch.jit.script for use
    outside of Python.
    """

    def __init__(self, num_vals: int) -> None:
        """Initializes the Image_ECF_3D module.

        Args:
            num_vals: The number of threshold values to discretize the ECF over.
                Must be at least 1.

        Raises:
            ValueError: If num_vals is less than 1.
        """
        super().__init__()
        num_vals = int(num_vals)
        if num_vals < 1:
            raise ValueError(f"num_vals must be at least 1, got {num_vals}.")
        self.num_vals: int = num_vals

    @staticmethod
    def cell_values_3D(arr: torch.Tensor) -> List[torch.Tensor]:
        """
        Creates a cubical complex with a function on its cells from a 3D tensor.
        The structure of the cubical complex is ignored with only the function values on
        the cells being recorded.

        Args:
            arr (torch.Tensor): A 3D tensor, nominally with values between 0 and 1.

        Returns:
            A list [vertex_values, edge_values, square_values, cube_values] of 1D
            tensors holding the function value (max over the cell's voxels) of every
            vertex, edge, square and cube, respectively.
        """
        arr = arr.float()

        vertex_values = arr.reshape(-1)

        # Edges along dim 0, dim 1 and dim 2.
        x_edge_values = torch.maximum(arr[1:, ...], arr[:-1, ...])
        y_edge_values = torch.maximum(arr[:, 1:, :], arr[:, :-1, :])
        z_edge_values = torch.maximum(arr[..., 1:], arr[..., :-1])
        edge_values = torch.cat([
            x_edge_values.reshape(-1),
            y_edge_values.reshape(-1),
            z_edge_values.reshape(-1)
        ], dim=0)

        # Squares normal to dim 0, dim 1 and dim 2, each built from two parallel edges.
        x_square_values = torch.maximum(y_edge_values[..., 1:], y_edge_values[..., :-1])
        y_square_values = torch.maximum(z_edge_values[1:, ...], z_edge_values[:-1, ...])
        z_square_values = torch.maximum(x_edge_values[:, 1:, :], x_edge_values[:, :-1, :])
        square_values = torch.cat([
            x_square_values.reshape(-1),
            y_square_values.reshape(-1),
            z_square_values.reshape(-1)
        ], dim=0)

        # A cube is the max of two opposite squares normal to dim 0 (its eight corners).
        cube_values = torch.maximum(x_square_values[1:, ...], x_square_values[:-1, ...]).reshape(-1)

        return [vertex_values, edge_values, square_values, cube_values]

    def forward(self, img_arr: torch.Tensor) -> torch.Tensor:
        """
        Calculates a discretization of the sublevel-set ECF of a 3D image.

        Indices are clamped to [0, num_vals - 1]. Intensities below 0 or above 1 are
        therefore treated as if the image were clipped to [0, 1], instead of causing
        an out-of-range scatter.

        Args:
            img_arr (torch.Tensor): A 3D tensor, nominally with values between 0 and 1.

        Returns:
            ecf (torch.Tensor): A 1D integer tensor of shape (self.num_vals,)
            containing the sublevel set ECF.
        """
        n = self.num_vals
        cell_values = self.cell_values_3D(img_arr)

        diff_ecf = torch.zeros(n, dtype=torch.int32, device=img_arr.device)

        # Cells of dimension d contribute (-1)**d at the first threshold index where
        # they appear. Clamp before the integer cast so huge values cannot overflow.
        for dim, values in enumerate(cell_values):
            indices = torch.ceil(values * (n - 1)).clamp(0, n - 1).long()
            sign = 1 if dim % 2 == 0 else -1
            diff_ecf.scatter_add_(
                0,
                indices,
                torch.full_like(indices, sign, dtype=torch.int32)
            )

        return torch.cumsum(diff_ecf, dim=0)
File without changes
@@ -0,0 +1,242 @@
1
+ from typing import Optional
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ from pyect import Complex
5
+ from PIL import Image
6
+
7
+
8
def image_to_grayscale_tensor(image_path: str, device: torch.device) -> torch.Tensor:
    """Load an image file as a 2D grayscale tensor with values in [0, 1].

    Args:
        image_path: Path to an image file that PIL can open.
        device: Device on which to place the returned tensor.

    Returns:
        A float tensor of shape (H, W).
    """
    # Open inside a context manager so the file handle is released promptly.
    # convert() produces an independent image, so it stays valid after the file closes.
    with Image.open(image_path) as image:
        grayscale_image = image.convert("L")
    # ToTensor maps the mode-'L' image to a (1, H, W) float tensor in [0, 1].
    tensor = transforms.ToTensor()(grayscale_image).squeeze(dim=0)
    return tensor.to(device)
17
+
18
+
19
def _enumerate_vertices(
    img_arr: torch.Tensor,
) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]":
    """Number the nonzero pixels of a 2D image and compute their coordinates and weights.

    Args:
        img_arr: A float tensor of shape (h, w).

    Returns:
        img_mask: Boolean (h, w) tensor, True where the pixel is nonzero.
        vertex_numbers: int64 (h, w) tensor. Entry (i, j) is the vertex index of pixel
            (i, j) when that pixel is nonzero, and 0 otherwise. It is only meaningful
            where img_mask is True.
        vertex_coords: (num_vertices, 2) tensor of recentered (x, y) coordinates, with
            x = column - (w - 1) / 2 and y = (h - 1) / 2 - row.
        vertex_weights: (num_vertices,) tensor of pixel intensities.
    """
    h, w = img_arr.shape
    img_mask = img_arr != 0
    rows, cols = torch.nonzero(img_mask, as_tuple=True)

    vertex_numbers = torch.zeros_like(img_arr, dtype=torch.int64)
    vertex_numbers[rows, cols] = torch.arange(
        rows.size(0), dtype=torch.int64, device=img_arr.device
    )

    vertex_coords = torch.stack([
        cols - (w - 1) / 2.0,
        (h - 1) / 2.0 - rows
    ], dim=1)
    vertex_weights = img_arr[rows, cols]
    return img_mask, vertex_numbers, vertex_coords, vertex_weights


def _collect_cells(
    img_mask: torch.Tensor,
    vertex_numbers: torch.Tensor,
    vertex_weights: torch.Tensor,
    patterns: "tuple[tuple[tuple[int, int], ...], ...]",
) -> "tuple[torch.Tensor, torch.Tensor]":
    """Gather all cells matching a set of corner patterns whose corners are all nonzero.

    A pattern is a tuple of (di, dj) offsets. The cell anchored at pixel (i, j) has
    corners (i + di, j + dj). Cells are emitted pattern by pattern and, within a
    pattern, in row-major anchor order.

    Returns:
        cell_vertices: (num_cells, len(pattern)) int64 tensor of vertex indices. The
            columns follow the order of the offsets in the pattern.
        cell_weights: (num_cells,) tensor holding the maximum vertex weight of each cell.
    """
    h, w = img_mask.shape
    all_vertices = []
    all_weights = []
    for offsets in patterns:
        # Anchors range over the pixels for which every corner stays inside the image.
        grid_h = max(h - max(di for di, _ in offsets), 0)
        grid_w = max(w - max(dj for _, dj in offsets), 0)
        cell_mask = torch.ones((grid_h, grid_w), dtype=torch.bool, device=img_mask.device)
        for di, dj in offsets:
            cell_mask &= img_mask[di:di + grid_h, dj:dj + grid_w]
        anchors = torch.nonzero(cell_mask, as_tuple=True)

        cell_vertices = torch.stack(
            [vertex_numbers[di:, dj:][anchors] for di, dj in offsets], dim=1
        )
        all_vertices.append(cell_vertices)
        all_weights.append(vertex_weights[cell_vertices].amax(dim=1))
    return torch.cat(all_vertices, dim=0), torch.cat(all_weights, dim=0)


def weighted_freudenthal(
    img_arr: torch.Tensor, device: Optional[torch.device] = None
) -> Complex:
    """
    Creates the weighted Freudenthal complex of an image array using a max function extension.

    Only nonzero pixels become vertices. Edges and triangles are kept only when all of
    their vertices are nonzero pixels. By default, the device of the input tensor is
    used unless a different device is specified.

    The complex is assembled as follows:

    * Vertices: a (num_vertices, 2) tensor of recentered pixel coordinates, where
      num_vertices is the number of nonzero pixels.
    * Vertex weights: a (num_vertices,) tensor of those pixel intensities.
    * Edges: a (num_valid_edges, 2) tensor of vertex indices. This covers horizontal
      and vertical edges, plus the diagonal (i, j)-(i+1, j+1).
    * Edge weights: a (num_valid_edges,) tensor with the maximum weight on the edge.
    * Triangles: a (num_valid_triangles, 3) tensor of vertex indices. These are the
      upper and lower halves of each pixel square, split along that diagonal.
    * Triangle weights: a (num_valid_triangles,) tensor with the maximum weight on the
      triangle.

    Args:
        img_arr (torch.Tensor): A grayscale image of shape (h, w).
        device (torch.device, optional): The device to create tensors on.
            If None, the device of the input tensor is used.

    Returns:
        Complex: A complex containing the weighted vertices, weighted edges, and weighted triangles.
    """
    device = img_arr.device if device is None else device
    img_arr = img_arr.float().to(device)

    img_mask, vertex_numbers, vertex_coords, vertex_weights = _enumerate_vertices(img_arr)
    vertices = (vertex_coords, vertex_weights)

    edge_patterns = (
        ((0, 0), (0, 1)),  # horizontal
        ((0, 0), (1, 0)),  # vertical
        ((0, 0), (1, 1)),  # diagonal
    )
    edges = _collect_cells(img_mask, vertex_numbers, vertex_weights, edge_patterns)

    triangle_patterns = (
        ((0, 0), (0, 1), (1, 1)),  # upper
        ((0, 0), (1, 0), (1, 1)),  # lower
    )
    triangles = _collect_cells(img_mask, vertex_numbers, vertex_weights, triangle_patterns)

    return Complex(vertices, edges, triangles, device=device)


def weighted_cubical(
    img_arr: torch.Tensor, device: Optional[torch.device] = None
) -> Complex:
    """
    Creates the weighted cubical complex of an image array.

    Only nonzero pixels become vertices. Edges and squares are kept only when all of
    their vertices are nonzero pixels. By default, the device of the input tensor is
    used unless a different device is specified.

    The complex is assembled as follows:

    * Vertices: a (num_vertices, 2) tensor of recentered pixel coordinates, where
      num_vertices is the number of nonzero pixels.
    * Vertex weights: a (num_vertices,) tensor of those pixel intensities.
    * Edges: a (num_valid_edges, 2) tensor of vertex indices (horizontal, then
      vertical).
    * Edge weights: a (num_valid_edges,) tensor with the maximum weight on the edge.
    * Squares: a (num_valid_squares, 4) tensor of vertex indices, ordered as the
      corners (i, j), (i+1, j), (i, j+1), (i+1, j+1).
    * Square weights: a (num_valid_squares,) tensor with the maximum weight on the
      square.

    Args:
        img_arr (torch.Tensor): A grayscale image of shape (h, w).
        device (torch.device, optional): The device to create tensors on.
            If None, the device of the input tensor is used.

    Returns:
        Complex: A cubical complex containing the weighted vertices, weighted edges, and weighted squares.
    """
    device = img_arr.device if device is None else device
    img_arr = img_arr.float().to(device)

    img_mask, vertex_numbers, vertex_coords, vertex_weights = _enumerate_vertices(img_arr)
    vertices = (vertex_coords, vertex_weights)

    edge_patterns = (
        ((0, 0), (0, 1)),  # horizontal
        ((0, 0), (1, 0)),  # vertical
    )
    edges = _collect_cells(img_mask, vertex_numbers, vertex_weights, edge_patterns)

    square_patterns = (
        ((0, 0), (1, 0), (0, 1), (1, 1)),
    )
    squares = _collect_cells(img_mask, vertex_numbers, vertex_weights, square_patterns)

    return Complex(vertices, edges, squares, n_type="cubical", device=device)
@@ -0,0 +1,204 @@
1
+ """Tools for working with simplicial complexes.
2
+
3
+ The Complex class is a collection of simplices, each of which is represented by a
4
+ tensor of coordinates and a tensor of weights.
5
+ """
6
+
7
+ from typing import Tuple, Optional
8
+
9
+ import torch
10
+ import warnings
11
+ import numpy.typing as npt
12
+
13
+ from .dtypes import COORDS_DTYPE, INDICES_DTYPE, WEIGHTS_DTYPE
14
+
15
+
16
class Complex:
    """A weighted simplicial (or cubical) complex of arbitrary dimension.

    The complex is stored in ``self.dimensions``, a tuple holding one
    ``(coords, weights)`` pair of tensors per dimension:

    * ``dimensions[0] = (vertex_coords, vertex_weights)``. ``vertex_coords`` has shape
      [num_vertices, d], where d is the ambient dimension, and ``vertex_weights`` has
      shape [num_vertices].
    * ``dimensions[k]`` for k > 0 is ``(cell_vertices, cell_weights)``.
      ``cell_vertices`` has shape [num_cells, c] and holds indices into the vertex
      tensor, with c = k + 1 for ``n_type="simplicial"`` and c = 2 ** k for
      ``n_type="cubical"``. ``cell_weights`` has shape [num_cells].
    """

    def __init__(
        self,
        *args: Tuple[torch.Tensor, torch.Tensor],
        vertex_dtype: torch.dtype = COORDS_DTYPE,
        index_dtype: torch.dtype = INDICES_DTYPE,
        weights_dtype: torch.dtype = WEIGHTS_DTYPE,
        device: Optional[torch.device] = None,
        n_type: str = "simplicial",
    ) -> None:
        """Initializes a complex.

        All tensors are cast to the given dtypes and moved to ``device``. If device is
        None, each tensor keeps its current device.

        Args:
            *args: A variable number of (coords, weights) tuples, one per dimension,
                starting at dimension 0.
                The first tuple holds the vertices. Its coords tensor must have shape
                [num_vertices, d].
                The tuple for dimension k > 0 holds indices into the vertices tensor,
                with shape [num_cells, c]. Here c = k + 1 for simplicial complexes and
                c = 2 ** k for cubical complexes.
                Each weights tensor must be 1D with one entry per row of its coords.
            vertex_dtype: The data type to use for the vertex coordinates.
            index_dtype: The data type to use for the cell indices.
            weights_dtype: The data type to use for the cell weights.
            device: The device to use for the tensors.
            n_type: The type of complex. Only "simplicial" and "cubical" are validated.
                Other values emit a warning.

        Raises:
            ValueError: If the shapes of the given tensors are inconsistent.
        """
        # Verify the shapes of all cells; raises ValueError on any mismatch.
        self._validate_dimensions(*args, n_type=n_type)

        # Dimension 0 holds coordinates; every higher dimension holds vertex indices.
        coord_dtypes = [vertex_dtype] + [index_dtype] * (len(args) - 1)
        self.dimensions = tuple(
            (
                coords.to(dtype=coord_dtypes[dim], device=device),
                weights.to(dtype=weights_dtype, device=device),
            )
            for dim, (coords, weights) in enumerate(args)
        )
        self.n_type = n_type

    @staticmethod
    def from_numpy(
        *args: Tuple[npt.NDArray, npt.NDArray],
        vertex_dtype: torch.dtype = COORDS_DTYPE,
        index_dtype: torch.dtype = INDICES_DTYPE,
        weights_dtype: torch.dtype = WEIGHTS_DTYPE,
        device: Optional[torch.device] = None,
        n_type: str = "simplicial",
    ) -> "Complex":
        """Initializes a complex from numpy arrays.

        Args:
            *args: A variable number of (coords, weights) tuples of numpy arrays, with
                the same layout as the tensors accepted by ``Complex.__init__``.
            vertex_dtype: The data type to use for the vertex coordinates.
            index_dtype: The data type to use for the cell indices.
            weights_dtype: The data type to use for the cell weights.
            device: The device to use for the tensors. If None, CUDA is used when
                available and the CPU otherwise.
            n_type: The type of complex. Only "simplicial" and "cubical" are validated.

        Returns:
            A new Complex whose tensors have the requested dtypes.
        """
        if device is None:
            device = (
                torch.device("cuda")
                if torch.cuda.is_available()
                else torch.device("cpu")
            )

        coord_dtypes = [vertex_dtype] + [index_dtype] * (len(args) - 1)
        dimensions = tuple(
            (
                torch.as_tensor(coords, device=device, dtype=coord_dtypes[i]),
                torch.as_tensor(weights, device=device, dtype=weights_dtype),
            )
            for i, (coords, weights) in enumerate(args)
        )
        # Forward the dtypes as well. Otherwise __init__ would re-cast everything to
        # the package defaults and silently discard the requested dtypes.
        return Complex(
            *dimensions,
            vertex_dtype=vertex_dtype,
            index_dtype=index_dtype,
            weights_dtype=weights_dtype,
            device=device,
            n_type=n_type,
        )

    def to(self, device: torch.device) -> "Complex":
        """Returns a new Complex on the given device, keeping dtypes and n_type.

        The original complex is not modified.
        """
        if len(self.dimensions) == 0:
            return Complex(device=device, n_type=self.n_type)

        vertex_coords, vertex_weights = self.dimensions[0]
        index_dtype = (
            self.dimensions[1][0].dtype if len(self.dimensions) > 1 else INDICES_DTYPE
        )
        # __init__ casts every weight tensor to one dtype, so the vertex weights'
        # dtype is representative of all dimensions.
        return Complex(
            *self.dimensions,
            vertex_dtype=vertex_coords.dtype,
            index_dtype=index_dtype,
            weights_dtype=vertex_weights.dtype,
            device=device,
            n_type=self.n_type,
        )

    def __getitem__(self, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns the (coords, weights) pair of the given dimension."""
        return self.dimensions[dim]

    def get_coords(self, dim: int) -> torch.Tensor:
        """Returns the coordinates (dim 0) or vertex indices (dim > 0) of the given dimension."""
        return self.dimensions[dim][0]

    def get_weights(self, dim: int) -> torch.Tensor:
        """Returns the weights of the cells of the given dimension."""
        return self.dimensions[dim][1]

    def top_dim(self) -> int:
        """Returns the top dimension of the complex."""
        return len(self) - 1

    def __len__(self) -> int:
        """Returns the number of dimensions stored in the complex."""
        return len(self.dimensions)

    def space_dim(self) -> int:
        """Returns the dimension of the space the complex is embedded in."""
        return self.dimensions[0][0].shape[1]

    def center_(self) -> "Complex":
        """
        Re-center the complex in-place so that the average vertex coordinate is at the origin.

        Returns self to allow chaining.
        """
        if len(self.dimensions) == 0:
            return self

        v_coords, v_weights = self.dimensions[0]
        if v_coords.numel() == 0:
            return self

        center = v_coords.mean(dim=0)
        new_v_coords = (v_coords - center).contiguous()

        dims = list(self.dimensions)
        dims[0] = (new_v_coords, v_weights)
        self.dimensions = tuple(dims)
        return self

    @staticmethod
    def _validate_dimensions(
        *args: Tuple[torch.Tensor, torch.Tensor], n_type: str
    ) -> None:
        """Raises ValueError if the (coords, weights) pairs have inconsistent shapes.

        Column counts of higher-dimensional cells are only checked for the
        "simplicial" and "cubical" n_types. For any other n_type a single warning
        is emitted.
        """
        known_type = n_type in ("simplicial", "cubical")
        for dim, simplex_list in enumerate(args):
            coords, weights = simplex_list[0], simplex_list[1]
            if coords.dim() != 2:
                raise ValueError(
                    f"Dimension {dim} simplices must be a 2d tensor."
                    + f" Got {coords.dim()} dimensions."
                )
            if weights.dim() != 1:
                raise ValueError(
                    f"Dimension {dim} weights must be a 1d tensor."
                    + f" Got {weights.dim()} dimensions."
                )
            if coords.shape[0] != weights.shape[0]:
                raise ValueError(
                    f"Dimension {dim} coordinates and weights must have the same number of simplices."
                    + f" Got {coords.shape[0]} simplices and {weights.shape[0]} weights."
                )

            # Vertex coordinates (dim 0) may have any number of columns.
            if dim == 0 or not known_type:
                continue
            expected_columns = dim + 1 if n_type == "simplicial" else 2 ** dim
            if coords.shape[1] != expected_columns:
                raise ValueError(
                    f"Dimension {dim} simplices must have {expected_columns} columns."
                    + f" Got {coords.shape[1]} columns."
                )

        # Validation of higher-dimensional cells is not implemented for other n_types.
        # Warn once, without raising an error.
        if not known_type and len(args) > 1:
            warnings.warn(f"Validation not implemented for n_type {n_type}. Proceed with caution.")
pyect/wecfs.py ADDED
@@ -0,0 +1,65 @@
1
+ """ For computing the WECFs of lower-star filtrations of
2
+ weighted simplicial/cubical complex with respect to a set of filter functions."""
3
+
4
+ import torch
5
+ from typing import List, Tuple
6
+
7
def compute_wecfs(
    complex_data: List[Tuple[torch.Tensor, torch.Tensor]],
    num_vals: int
) -> torch.Tensor:
    """Calculates a discretization of the WECFs of a weighted complex with respect to a set of filter functions.

    Filter values are mapped linearly from the symmetric range [-max_val, max_val] onto
    the indices range(num_vals). Here max_val is the largest absolute filter value;
    1 is used instead when every filter value is zero or there are no vertices. Each
    cell enters the filtration at the largest index of its vertices (lower-star) and
    contributes (-1)**i times its weight, where i is its dimension.

    Args:
        complex_data: A weighted simplicial or cubical complex with a collection of
            filter functions, represented as a list of pairs of tensors.
            complex_data[0] = (filters, v_weights):
                filters (torch.Tensor): A tensor of shape (k_0, m), where k_0 is the
                    number of vertices and m is the number of filter functions. Each
                    column contains the values of a filter function on the vertices.
                v_weights (torch.Tensor): A tensor of shape (k_0,) holding the weights
                    of the vertices.
            for i > 0, complex_data[i] = (simp_verts, simp_weights):
                simp_verts (torch.Tensor): A tensor of shape (k_i, c_i), where k_i is
                    the number of i-cells. Each row holds the vertex indices of one
                    i-cell, e.g. c_i = i + 1 for simplices or 2 ** i for cubes.
                simp_weights (torch.Tensor): A tensor of shape (k_i,) holding the
                    weights of the i-cells.
        num_vals: Number of discretization values.

    Returns:
        wecfs (torch.Tensor): A 2d float tensor of shape (m, num_vals) containing the WECFs.
    """

    filters = complex_data[0][0].float()
    m = filters.size(dim=1)
    device = filters.device
    v_weights = complex_data[0][1].to(device=device, dtype=torch.float32)

    expanded_v_weights = v_weights.unsqueeze(0).expand(m, -1)  # Expand to shape (m, k_0)

    # Half-width of the symmetric value range. Fall back to 1 for empty or all-zero
    # filters, which would otherwise fail in amax or divide by zero.
    max_val = torch.ones((), dtype=torch.float32, device=device)
    if filters.numel() > 0:
        abs_max = filters.abs().amax()
        max_val = torch.where(abs_max > 0, abs_max, max_val)

    # Map the values of the filter functions to indices in range(num_vals)
    v_indices = torch.ceil(
        (num_vals - 1) * (max_val + filters) / (2.0 * max_val)
    ).clamp(0, num_vals - 1).long()

    # Initialize the differentiated WECFs
    diff_wecfs = torch.zeros((m, num_vals), dtype=torch.float32, device=device)

    # Add the contribution of the vertices to the differentiated WECFs
    diff_wecfs.scatter_add_(1, v_indices.T, expanded_v_weights)

    for i in range(1, len(complex_data)):
        simp_verts = complex_data[i][0].to(device=device, dtype=torch.long)
        simp_weights = complex_data[i][1].to(device=device, dtype=torch.float32)

        # Cells of odd dimension enter the Euler characteristic with a minus sign.
        expanded_simp_weights = (-1) ** i * simp_weights.unsqueeze(0).expand(m, -1)

        # (k_i, c_i, m) -> (k_i, m): a cell appears once its last vertex has appeared.
        simp_indices = v_indices[simp_verts]
        max_simp_indices = torch.amax(simp_indices, dim=1)

        diff_wecfs.scatter_add_(1, max_simp_indices.T, expanded_simp_weights)

    return torch.cumsum(diff_wecfs, dim=1)
pyect/wect.py ADDED
@@ -0,0 +1,130 @@
1
+ """For computing the WECT of a weighted geometric simplicial/cubical complex embedded in R^n."""
2
+
3
+ import torch
4
+ from typing import List, Tuple
5
+
6
+
7
+ class WECT(torch.nn.Module):
8
+ """A torch module for computing the Weighted Euler Characteristic Transform (WECT) of a simplicial complex discretized over a grid.
9
+
10
+ This module may be used just for computing the WECT, or used as a layer in a neural network.
11
+ Internally, the module stores the directions and number of heights used for sampling, so repeated forward calls
12
+ do not require these parameters to be passed in, and allow streamlined loading/saving of the module for consistent
13
+ computation.
14
+
15
+ This module can also be converted to TorchScript using torch.jit.script for use
16
+ outside of Python.
17
+ """
18
+
19
+ def __init__(self, dirs: torch.Tensor, num_heights: int) -> None:
20
+ """Initializes the WECT module.
21
+
22
+ The initialized module is designed to compute the WECT of a simplicial complex
23
+ embedded in R^[dirs.shape[1]], using dirs.shape[0] directions for sampling.
24
+ The discretization of the WECT is parameterized by num_heights distinct height values.
25
+
26
+ Args:
27
+ dirs: An (d x n) tensor of directions to use for sampling.
28
+ num_heights: A constant tensor, with the number of distinct height
29
+ values to round to as an integer
30
+ """
31
+ super().__init__()
32
+ dirs = torch.nn.functional.normalize(dirs, p=2, dim=1, eps=1e-12)
33
+ self.register_buffer("dirs", dirs)
34
+ self.num_heights: int = int(num_heights)
35
+
36
+ def _vertex_indices(
37
+ self,
38
+ vertex_coords: torch.Tensor,
39
+ ) -> torch.Tensor:
40
+ """Calculates the height values of each vertex and converts them to an index in range(num_heights).
41
+
42
+ Args:
43
+ vertex_coords (torch.Tensor): A tensor of shape (k_0, n) with rows representing the coordinates of the vertices.
44
+
45
+ Returns:
46
+ torch.Tensor: A tensor of shape (k_0, d) with the height indices of each vertex in each direction.
47
+ """
48
+
49
+ v_norms = torch.norm(vertex_coords, dim=1)
50
+ max_height = torch.amax(v_norms)
51
+ v_heights = torch.matmul(vertex_coords, self.dirs.T)
52
+
53
+ # The case where all vertices are at the origin
54
+ if max_height.item() == 0.0:
55
+ return torch.zeros((v_heights.size(0), self.dirs.size(0)), dtype=torch.long, device=self.dirs.device)
56
+
57
+ v_indices = torch.ceil(
58
+ (self.num_heights - 1) * (max_height + v_heights) / (2.0 * max_height)
59
+ ).clamp(0, self.num_heights - 1).long()
60
+
61
+ return v_indices
62
+
63
+ def forward(
64
+ self,
65
+ complex_data: List[Tuple[torch.Tensor, torch.Tensor]],
66
+ ) -> torch.Tensor:
67
+ """Calculates a discretization of the WECT of a complex embedded in n-dimensional space.
68
+
69
+ Args:
70
+ complex_data: A weighted simplicial or cubical complex, represented as a list of pairs of tensors.
71
+ complex_data[0] = (v_coords, v_weights):
72
+ v_coords (torch.Tensor): A tensor of shape (k_0, n) where k_0 is the number of vertices.
73
+ Rows are the coordinates of the vertices.
74
+
75
+ v_weights (torch.Tensor): A tensor of shape (k_0). Values are the weights of the vertices.
76
+
77
+ for i > 0:
78
+ complex_data[i] = (simp_verts, simp_weights):
79
+ simp_verts (torch.Tensor): A tensor of shape (k_i, i+1) where k_i is the number of i-simplices.
80
+ Rows are the vertex sets of the i-simplices.
81
+
82
+ simp_weights (torch.Tensor): A tensor of shape (k_i). Values are the weights of the i-simplices.
83
+
84
+ Returns:
85
+ wect (torch.Tensor): A 2d tensor of shape (self.dirs.shape[0], self.num_heights)
86
+ containing the WECT.
87
+ """
88
+
89
+ d = self.dirs.size(dim=0)
90
+ h = self.num_heights
91
+
92
+ if h <= 0:
93
+ raise ValueError("num_heights must be positive.")
94
+
95
+ device = self.dirs.device
96
+ v_coords = complex_data[0][0].to(device=device, dtype=torch.float32)
97
+ v_weights = complex_data[0][1].to(device=device, dtype=torch.float32)
98
+
99
+ # Check for empty inputs
100
+ if v_coords.size(0) == 0:
101
+ return torch.zeros((d, h), dtype=torch.float32, device=device)
102
+
103
+ expanded_v_weights = v_weights.unsqueeze(0).expand(
104
+ d, -1
105
+ ) # Expand to shape (d, k_0)
106
+
107
+ # Initialize the differentiated WECT
108
+ diff_wect = torch.zeros((d, h), dtype=torch.float32, device=device)
109
+
110
+ # Compute the height index of each vertex
111
+ v_indices = self._vertex_indices(v_coords)
112
+
113
+ # Add the contribution of the vertices to the differentiated WECT
114
+ diff_wect.scatter_add_(1, v_indices.T, expanded_v_weights)
115
+
116
+ for i in range(1, len(complex_data)):
117
+ simp_verts = complex_data[i][0].to(device=device, dtype=torch.long)
118
+ simp_weights = complex_data[i][1].to(device=device, dtype=torch.float32)
119
+
120
+ # Expand to shape (d, k_i)
121
+ expanded_simp_weights = (-1) ** i * simp_weights.unsqueeze(0).expand(d, -1)
122
+
123
+ # Compute the maximum index for each simplex's vertices
124
+ simp_indices = v_indices[simp_verts]
125
+ max_simp_indices = torch.amax(simp_indices, dim=1)
126
+
127
+ # Add the contribution of the i-simplices to the differentiated WECT
128
+ diff_wect.scatter_add_(1, max_simp_indices.T, expanded_simp_weights)
129
+
130
+ return torch.cumsum(diff_wect, dim=1)
@@ -0,0 +1,80 @@
1
+ Metadata-Version: 2.4
2
+ Name: pyect
3
+ Version: 0.1.0
4
+ Summary: Generalized computation of the WECT using PyTorch.
5
+ Home-page: https://github.com/compTAG/pyECT
6
+ Author: Alex McCleary, Eli Quist, Jack Ruder, Jacob Sriraman
7
+ Author-email: eli.quist@student.montana.edu
8
+ License: MIT
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: License :: OSI Approved :: MIT License
11
+ Classifier: Operating System :: OS Independent
12
+ Requires-Python: >=3.8
13
+ Description-Content-Type: text/markdown
14
+ License-File: LICENSE
15
+ Requires-Dist: torch
16
+ Requires-Dist: Pillow
17
+ Dynamic: author
18
+ Dynamic: author-email
19
+ Dynamic: classifier
20
+ Dynamic: description
21
+ Dynamic: description-content-type
22
+ Dynamic: home-page
23
+ Dynamic: license
24
+ Dynamic: license-file
25
+ Dynamic: requires-dist
26
+ Dynamic: requires-python
27
+ Dynamic: summary
28
+
29
+ # pyECT
30
+
31
+ The Weighted Euler Characteristic Transform (WECT) is a mathematical tool
32
+ used to analyze and summarize geometric and topological features of data.
33
+ This package provides an efficient and simple implementation of the WECT using
34
+ PyTorch.
35
+
36
+ This codebase accompanies the following paper (and should be cited if you use
37
+ this package):
38
+
39
+ ```
40
+ TODO: Add Citation
41
+ ```
42
+
43
+ ## Installation
44
+
45
+ To install `pyECT`, use pip:
46
+
47
+ ```bash
48
+ pip install pyect
49
+ ```
50
+
51
+ ## Usage
52
+
53
+ Here's a simple example of how to use `pyECT`:
54
+
55
+ ```python
56
+ import torch
57
+ from pyect import WECT, sample_directions_2d
58
+
59
+ # A weighted complex: (vertex coords, vertex weights), then (edges, edge weights)
60
+ complex_data = [
61
+ (torch.tensor([[0.0, 0.0], [1.0, 0.0]]), torch.tensor([1.0, 1.0])),
62
+ (torch.tensor([[0, 1]]), torch.tensor([1.0])),
63
+ ]
64
+ wect = WECT(sample_directions_2d(8), num_heights=16)
65
+ result = wect(complex_data) # tensor of shape (8, 16)
66
+ print("WECT result:", result)
67
+ ```
68
+
69
+ For more detailed examples, see the `examples/` directory of the [GitHub repository](https://github.com/compTAG/pyECT).
70
+
71
+ ## Contributing
72
+
73
+ Contributions are welcome! If you'd like to contribute, please fork the
74
+ repository and submit a pull request. For major changes, please open an issue
75
+ first to discuss what you'd like to change.
76
+
77
+ ## License
78
+
79
+ This project is licensed under the MIT License. See the [LICENSE](LICENSE)
80
+ file for details.
@@ -0,0 +1,14 @@
1
+ pyect/__init__.py,sha256=roSWMLV7yFltMExesaIqRyEZANwYaKuzBJOPlDuQa7E,302
2
+ pyect/directions.py,sha256=XmCyfc6GrRHmgeeeRVVTCJs2QvE4li0pU_pc98TDek0,890
3
+ pyect/dtypes.py,sha256=_GxQ8cm3UkRo-CU3v--4D_DFd0NhTIZl_u0yuIADMIY,144
4
+ pyect/image_ecf.py,sha256=2C_ga6sgMOO8XAua8tjbsFXTelmwj5arIN_2sCVwgc4,8175
5
+ pyect/tensor_complex.py,sha256=R_lgSOmTkK5PV8Q5BZF7MO73PU6ROGDew5mWwftDcE0,8345
6
+ pyect/wecfs.py,sha256=C__WLyuELs2wmhLIK2aZ-suNffotUmWyL_Q1S7_SYJA,2771
7
+ pyect/wect.py,sha256=g1wXsnkUx3m7NN9n3Tu6eI7hFWxJC9w_1htO0p98lAg,5560
8
+ pyect/preprocessing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
+ pyect/preprocessing/image_processing.py,sha256=hMWoZgytBruwdM6Op4HUWlbqoYJKsT2rs_hp7U7lULM,9967
10
+ pyect-0.1.0.dist-info/licenses/LICENSE,sha256=pk1cMbtWYROTtwmnpV4fWVheiH-vqVq_kzP4mtkCywY,1144
11
+ pyect-0.1.0.dist-info/METADATA,sha256=Dhtj726KbcfZTCY59OJlBpehDkjF4CSxeMt4aR6pcnI,1964
12
+ pyect-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
13
+ pyect-0.1.0.dist-info/top_level.txt,sha256=0kDTwnLAArCjbb33d93Hg0IvoHhMKV9Fv9P-rQ-p6FY,6
14
+ pyect-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.9.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,23 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Montana State University Computational Topology and
4
+ Geometry (CompTaG) Research Group.
5
+
6
+
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ of this software and associated documentation files (the "Software"), to deal
9
+ in the Software without restriction, including without limitation the rights
10
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the Software is
12
+ furnished to do so, subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included in
15
+ all copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23
+ THE SOFTWARE.
@@ -0,0 +1 @@
1
+ pyect