miblab_ssa-0.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
miblab_ssa/__init__.py ADDED
@@ -0,0 +1,14 @@
+ from .normalize import (
+     normalize_kidney_mask
+ )
+ from .ssa import (
+     features_from_dataset_zarr,
+     pca_from_features_zarr,
+     coefficients_from_features_zarr,
+     modes_from_pca_zarr,
+ )
+ from .metrics import (
+     hausdorff_matrix_zarr,
+     dice_matrix_zarr
+ )
+ from . import sdf_ft, sdf_cheby, lb, zernike
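
Note: the imports above define the package's public API. A minimal usage sketch, assuming a Zarr store at "dataset.zarr" holding a 'masks' array as the metrics functions below expect; the path is illustrative, not a package default:

    import miblab_ssa

    dice = miblab_ssa.dice_matrix_zarr("dataset.zarr")
    haus = miblab_ssa.hausdorff_matrix_zarr("dataset.zarr")
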
miblab_ssa/lb.py ADDED
@@ -0,0 +1,260 @@
+ import numpy as np
+ from skimage import measure
+ import trimesh
+ from scipy.sparse import coo_matrix, diags
+ from scipy.sparse.linalg import eigsh
+
+ # -------------------------------
+ # Helper: convert trimesh to mask
+ # -------------------------------
+ def mesh_to_mask(mesh, shape):
+     """
+     Rasterize mesh into a 3D binary mask.
+     """
+     mask = np.zeros(shape, dtype=bool)
+     # Use trimesh voxelization; voxel indices are assumed to fit within `shape`
+     vox = mesh.voxelized(pitch=1.0)
+     indices = vox.sparse_indices
+     mask[indices[:, 0], indices[:, 1], indices[:, 2]] = True
+     return mask
+
+ # -------------------------------
+ # 1️⃣ Mask → Mesh
+ # -------------------------------
+ def mask_to_mesh(mask, spacing=(1.0, 1.0, 1.0)):
+     """
+     Convert a 3D binary mask to a triangular mesh.
+     """
+     verts, faces, normals, values = measure.marching_cubes(mask.astype(float), level=0.5, spacing=spacing)
+     mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False)
+     return mesh
+
+
+ def mask_to_mesh_fixed_vertices(mask: np.ndarray, spacing=(1.0, 1.0, 1.0), target_vertices: int = 5000) -> trimesh.Trimesh:
+     """
+     Convert a 3D binary mask to a mesh with a fixed number of vertices.
+
+     Parameters
+     ----------
+     mask : np.ndarray
+         3D binary mask.
+     spacing : tuple of float
+         Voxel size.
+     target_vertices : int
+         Approximate number of vertices in the simplified mesh.
+
+     Returns
+     -------
+     mesh_simplified : trimesh.Trimesh
+         Mesh object with approximately target_vertices vertices.
+     """
+     # Step 1: extract surface using marching cubes
+     verts, faces, normals, _ = measure.marching_cubes(mask.astype(float), level=0.5, spacing=spacing)
+
+     mesh = trimesh.Trimesh(vertices=verts, faces=faces, vertex_normals=normals, process=True)
+
+     # Step 2: simplify / resample to the target number of vertices.
+     # A closed triangle mesh has roughly twice as many faces as vertices,
+     # so the decimation targets 2 * target_vertices faces. Needs testing.
+     mesh_simplified = mesh.simplify_quadric_decimation(2 * target_vertices)
+
+     return mesh_simplified
+
+
+ # -------------------------------
+ # 2️⃣ Preprocessing for invariance (FIXED)
+ # -------------------------------
+ def preprocess_mesh(mesh):
+     """
+     Apply translation, scaling, and PCA alignment.
+     Returns the processed mesh and the preprocessing parameters for inverse mapping.
+     """
+     # Center
+     centroid = mesh.vertices.mean(axis=0)
+     mesh_c = mesh.copy()
+     mesh_c.vertices = mesh.vertices - centroid
+
+     # Scale
+     scale = np.sqrt((mesh_c.vertices**2).sum(axis=1).mean())
+     mesh_s = mesh_c.copy()
+     mesh_s.vertices = mesh_c.vertices / scale
+
+     # PCA alignment
+     cov = np.cov(mesh_s.vertices.T)
+     eigvals, eigvecs = np.linalg.eigh(cov)
+     idx = np.argsort(eigvals)[::-1]
+     eigvecs = eigvecs[:, idx]
+     mesh_aligned = mesh_s.copy()
+     mesh_aligned.vertices = mesh_s.vertices @ eigvecs
+
+     # Save parameters for inverse transformation
+     params = {"centroid": centroid, "scale": scale, "pca_eigvecs": eigvecs}
+     return mesh_aligned, params
+
+ def inverse_preprocess_mesh(vertices, params):
+     """
+     Map reconstructed vertices back to original coordinates.
+     """
+     v = vertices @ params["pca_eigvecs"].T  # undo PCA
+     v = v * params["scale"]                 # undo scaling
+     v = v + params["centroid"]              # undo translation
+     return v
+
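
Note: a quick round-trip check of the two helpers above; the icosphere stands in for a real kidney mesh and is purely illustrative:

    import numpy as np
    import trimesh

    mesh = trimesh.creation.icosphere(subdivisions=3)
    mesh_aligned, params = preprocess_mesh(mesh)
    restored = inverse_preprocess_mesh(mesh_aligned.vertices, params)
    # The PCA basis is orthogonal (E @ E.T = I), so the mapping inverts exactly
    print(np.allclose(restored, mesh.vertices))  # True
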
+ # -------------------------------
+ # 3️⃣ Laplace-Beltrami Eigenfunctions
+ # -------------------------------
+ def cotangent_laplacian(mesh):
+     vertices = mesh.vertices
+     faces = mesh.faces
+
+     def cotangent(a, b, c):
+         ba = b - a
+         ca = c - a
+         cos_angle = np.dot(ba, ca)
+         sin_angle = np.linalg.norm(np.cross(ba, ca))
+         return cos_angle / (sin_angle + 1e-10)
+
+     I, J, V = [], [], []
+     n = len(vertices)
+     for face in faces:
+         i, j, k = face
+         vi, vj, vk = vertices[i], vertices[j], vertices[k]
+         cot_alpha = cotangent(vj, vi, vk)
+         cot_beta = cotangent(vk, vj, vi)
+         cot_gamma = cotangent(vi, vk, vj)
+         for (p, q, w) in [(i, j, cot_gamma), (j, i, cot_gamma),
+                           (j, k, cot_alpha), (k, j, cot_alpha),
+                           (k, i, cot_beta), (i, k, cot_beta)]:
+             I.append(p)
+             J.append(q)
+             V.append(w / 2)
+
+     L = coo_matrix((V, (I, J)), shape=(n, n))
+     L = diags(L.sum(axis=1).A1) - L
+     return L
+
+ def lb_eigen_decomposition(mesh, k=50):
+     L = cotangent_laplacian(mesh)
+     M = diags(np.ones(mesh.vertices.shape[0]))
+     eigvals, eigvecs = eigsh(L, k=k, M=M, sigma=1e-8, which='LM')
+     return eigvals, eigvecs
+
+ def surface_to_coefficients(mesh, k=50):
+     eigvals, eigvecs = lb_eigen_decomposition(mesh, k=k)
+     coords = mesh.vertices
+     coeffs = eigvecs.T @ coords  # shape (k, 3)
+     return coeffs, eigvecs, eigvals
+
+
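
Note: a minimal sketch of the spectral step on a toy surface; k=20 and the icosphere are illustrative choices, not package defaults:

    import trimesh

    mesh = trimesh.creation.icosphere(subdivisions=3)  # 642 vertices
    coeffs, eigvecs, eigvals = surface_to_coefficients(mesh, k=20)
    print(coeffs.shape, eigvecs.shape, eigvals.shape)  # (20, 3) (642, 20) (20,)
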
+ def rotationally_invariant_lb_coeffs(coeffs, eigvals, k=100):
+     """
+     Compute rotationally invariant Laplace–Beltrami spectral coefficients.
+
+     Parameters
+     ----------
+     coeffs : (k, 3) array
+         Spectral coefficients from surface_to_coefficients
+     eigvals : (k,) array
+         Laplace–Beltrami eigenvalues
+     k : int
+         Number of eigenmodes to use
+
+     Returns
+     -------
+     invariants : (k-1,) array
+         Rotationally invariant spectral coefficients
+     eigvals : (k-1,) array
+         Normalized Laplace–Beltrami eigenvalues
+     """
+     invariants = np.linalg.norm(coeffs, axis=1)  # sqrt(sum over x,y,z)
+     invariants /= np.linalg.norm(invariants)
+
+     # Optional: normalize eigenvalues by first non-zero eigenvalue
+     eigvals = eigvals / eigvals[1] if eigvals[1] != 0 else eigvals
+
+     # Drop the first mode (zero eigenvalue) from the descriptor since it is trivial
+     eigvals = eigvals[1:]       # length k-1
+     invariants = invariants[1:]
+
+     # A combined, normalized descriptor can also be formed; note it is
+     # currently computed but not returned
+     descriptor = np.concatenate([eigvals[:k], invariants[:k]])
+     descriptor /= np.linalg.norm(descriptor)  # normalize final vector
+
+     return invariants, eigvals
+
+
+ # def coefficients_to_surface(coeffs, eigvecs):
+ #     reconstructed = eigvecs @ coeffs
+ #     return reconstructed
+
+ def coefficients_to_surface(coeffs, eigvecs, threshold=None):
+     """
+     Reconstruct surface vertices from coefficients and eigenvectors.
+
+     Args:
+         coeffs (np.ndarray): shape (k, 3), coefficients from surface_to_coefficients
+         eigvecs (np.ndarray): shape (n, k), eigenvectors of Laplace-Beltrami
+         threshold (float, optional): percentage (0-100).
+             If given, only the top threshold% dominant modes (by coefficient norm)
+             are kept in the reconstruction.
+
+     Returns:
+         np.ndarray: reconstructed vertices, shape (n, 3)
+     """
+     if threshold is not None:
+         # Compute importance of each eigenfunction
+         norms = np.linalg.norm(coeffs, axis=1)
+         k = len(norms)
+
+         # How many to keep
+         keep = max(1, int(np.ceil(k * threshold / 100.0)))
+
+         # Select indices of the most important modes
+         idx_sorted = np.argsort(norms)[::-1]
+         idx_keep = idx_sorted[:keep]
+
+         # Zero out the others
+         coeffs_filtered = np.zeros_like(coeffs)
+         coeffs_filtered[idx_keep] = coeffs[idx_keep]
+
+         reconstructed = eigvecs @ coeffs_filtered
+     else:
+         reconstructed = eigvecs @ coeffs
+
+     return reconstructed
+
+
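
Note: truncating to the dominant modes yields a smoothed reconstruction; the 20% threshold is illustrative:

    import trimesh

    mesh = trimesh.creation.icosphere(subdivisions=3)
    coeffs, eigvecs, _ = surface_to_coefficients(mesh, k=50)
    verts_smooth = coefficients_to_surface(coeffs, eigvecs, threshold=20)  # keep top 10 of 50 modes
    print(verts_smooth.shape)  # (642, 3)
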
+ def pipeline(mask, k=50):
+     # mesh = mask_to_mesh(mask)
+     # A fixed number of vertices is necessary to achieve comparable coefficients
+     mesh = mask_to_mesh_fixed_vertices(mask)
+     mesh_proc, params = preprocess_mesh(mesh)
+     coeffs, eigvecs, eigvals = surface_to_coefficients(mesh_proc, k=k)
+     return coeffs, eigvecs, mesh_proc, params
+
+ def eigvals(mask, k=100, normalize=False):
+     # NB: the local variable eigvals below shadows this function's name
+     mesh = mask_to_mesh(mask)
+     coeffs, eigvecs, eigvals = surface_to_coefficients(mesh, k=k)
+     if normalize:
+         # Normalize eigenvalues by the largest eigenvalue
+         # (alternative: by the first non-zero eigenvalue)
+         # eigvals = eigvals / eigvals[1] if eigvals[1] != 0 else eigvals
+         eigvals = eigvals / np.max(eigvals)
+     # Drop the first eigenvalue (zero mode) from the descriptor since it is trivial
+     eigvals = eigvals[1:]  # length k-1
+     return eigvals
+
+
+ def process(mesh, k=10, threshold=None):
+     mesh_proc, params = preprocess_mesh(mesh)
+
+     # Compute LB coefficients (invariant)
+     coeffs, eigvecs, eigvals = surface_to_coefficients(mesh_proc, k=k)
+
+     # Reconstruct in normalized/aligned space
+     reconstructed_vertices_proc = coefficients_to_surface(coeffs, eigvecs, threshold=threshold)
+
+     # Map reconstruction back to original coordinates
+     reconstructed_vertices_orig = inverse_preprocess_mesh(reconstructed_vertices_proc, params)
+
+     # Build reconstructed mesh
+     reconstructed_mesh = mesh.copy()
+     reconstructed_mesh.vertices = reconstructed_vertices_orig
+
+     return coeffs, eigvals, reconstructed_mesh
+
+
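
Note: an end-to-end sketch of lb.pipeline on a synthetic spherical mask. The mask construction is illustrative, and the decimation step inside mask_to_mesh_fixed_vertices relies on trimesh's quadric-decimation backend being installed:

    import numpy as np
    from miblab_ssa import lb

    # Synthetic binary mask: a sphere of radius 20 voxels in a 64³ volume
    zz, yy, xx = np.mgrid[:64, :64, :64]
    mask = (zz - 32)**2 + (yy - 32)**2 + (xx - 32)**2 < 20**2

    coeffs, eigvecs, mesh_proc, params = lb.pipeline(mask, k=30)
    print(coeffs.shape)  # (30, 3)
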
miblab_ssa/metrics.py ADDED
@@ -0,0 +1,280 @@
+ import logging
+ import numpy as np
+ from skimage import measure
+ from scipy.spatial import cKDTree
+ import dask
+ from dask.diagnostics import ProgressBar
+ import dask.array as da
+ import psutil
+ import zarr
+
+
+ def dice_coefficient(vol_a, vol_b):
+     """
+     Compute the Dice similarity coefficient between two binary masks.
+
+     Parameters
+     ----------
+     vol_a : np.ndarray
+         First binary mask (values should be 0 or 1).
+     vol_b : np.ndarray
+         Second binary mask (values should be 0 or 1).
+
+     Returns
+     -------
+     float
+         Dice coefficient, ranging from 0 (no overlap) to 1 (perfect overlap).
+
+     Notes
+     -----
+     The Dice coefficient is defined as:
+         Dice = 2 * |A ∩ B| / (|A| + |B|)
+     """
+     vol_a = vol_a.astype(bool)
+     vol_b = vol_b.astype(bool)
+     intersection = np.logical_and(vol_a, vol_b).sum()
+     size_a = vol_a.sum()
+     size_b = vol_b.sum()
+     if size_a + size_b == 0:
+         return 1.0
+     return 2.0 * intersection / (size_a + size_b)
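
Note: for instance, on two overlapping toy masks:

    a = np.array([1, 1, 1, 0])
    b = np.array([0, 1, 1, 1])
    print(dice_coefficient(a, b))  # 2*2 / (3+3) ≈ 0.667
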
+
+ def surface_distances(vol_a, vol_b, spacing=(1.0, 1.0, 1.0)):
+     """
+     Compute surface distances (Hausdorff and mean) between two binary volumes.
+     Args:
+         vol_a, vol_b: binary 3D arrays
+         spacing: voxel spacing (dz, dy, dx)
+     Returns:
+         hausdorff, mean_dist
+     """
+     # extract meshes
+     verts_a, faces_a, _, _ = measure.marching_cubes(vol_a.astype(np.uint8), level=0.5, spacing=spacing)
+     verts_b, faces_b, _, _ = measure.marching_cubes(vol_b.astype(np.uint8), level=0.5, spacing=spacing)
+
+     # build kd-trees
+     tree_a = cKDTree(verts_a)
+     tree_b = cKDTree(verts_b)
+
+     # distances from A→B and B→A
+     d_ab, _ = tree_b.query(verts_a, k=1)
+     d_ba, _ = tree_a.query(verts_b, k=1)
+
+     hausdorff = max(d_ab.max(), d_ba.max())
+     mean_dist = 0.5 * (d_ab.mean() + d_ba.mean())
+     return hausdorff, mean_dist
+
+
+ def dice_matrix_in_memory(M: np.ndarray):
+     """
+     Computes a Dice similarity matrix for a stack of binary masks using
+     vectorized matrix multiplication.
+     """
+     # 1. Ensure the matrix is 2D: one flattened mask per row
+     M = M.reshape((M.shape[0], -1))
+
+     # 2. Convert from Boolean (True/False) to Integer (1/0)
+     # This ensures the dot product counts overlapping voxels.
+     M = M.astype(np.int32)
+
+     # 3. Vectorized Intersection Calculation (Matrix Multiplication)
+     # Intersections[i, j] = dot_product(mask_i, mask_j)
+     # This replaces the nested loop. M.T means M transpose.
+     intersection_matrix = M @ M.T
+
+     # 4. Compute Dice Score
+     # Formula: 2 * |A ∩ B| / (|A| + |B|)
+
+     # The diagonal of the intersection matrix is |A ∩ A|, which is just |A| (the volume)
+     volumes = intersection_matrix.diagonal()
+
+     # Broadcasting sum: creates a matrix where cell [i, j] = volume[i] + volume[j]
+     volumes_sum_matrix = volumes[:, None] + volumes[None, :]
+
+     # Avoid division by zero (though volumes shouldn't be 0 for valid masks).
+     # If both volumes are 0, Dice is taken to be 1.0 (empty matches empty);
+     # np.errstate suppresses the warning and the NaNs are fixed below.
+     with np.errstate(divide='ignore', invalid='ignore'):
+         dice_matrix = (2 * intersection_matrix) / volumes_sum_matrix
+
+     # Handle NaN cases where volumes_sum_matrix is 0
+     dice_matrix = np.nan_to_num(dice_matrix, nan=1.0)
+
+     return dice_matrix
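
Note: a usage sketch on a small random stack; the shape and seed are illustrative:

    rng = np.random.default_rng(0)
    masks = rng.random((5, 16, 16, 16)) > 0.5  # (n_subjects, D, H, W)
    D = dice_matrix_in_memory(masks)
    print(D.shape, D.diagonal())  # (5, 5), ones on the diagonal
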
+
+
+
+ def get_optimal_chunk_size(shape, dtype, target_mb=250):
+     """
+     Calculates the optimal number of masks per chunk based on the specific dtype size.
+     """
+     # 1. Dynamically get bytes per voxel based on the dtype argument
+     #    np.int32 -> 4 bytes
+     #    np.float64 -> 8 bytes
+     #    np.bool_ -> 1 byte
+     bytes_per_voxel = np.dtype(dtype).itemsize
+
+     # 2. Calculate size of ONE mask in Megabytes (MB)
+     one_mask_bytes = np.prod(shape) * bytes_per_voxel
+     one_mask_mb = one_mask_bytes / (1024**2)
+
+     # 3. Constraint A: Dask target chunk size (~250 MB)
+     if one_mask_mb > target_mb:
+         dask_optimal_count = 1
+     else:
+         dask_optimal_count = int(target_mb / one_mask_mb)
+
+     # 4. Constraint B: System RAM safety net (10% of available RAM)
+     available_ram_mb = psutil.virtual_memory().available / (1024**2)
+     safe_ram_limit_mb = available_ram_mb * 0.10
+     ram_limited_count = int(safe_ram_limit_mb / one_mask_mb)
+
+     # 5. Pick the safer (smaller) number
+     final_count = min(dask_optimal_count, ram_limited_count)
+
+     return max(1, final_count)
+
+
+ def dice_matrix_zarr(zarr_path, chunk_size='auto'):
+     """
+     Computes a Dice similarity matrix with auto-optimized memory chunking.
+     """
+     # 1. Connect to Zarr
+     d_masks = da.from_zarr(zarr_path, component='masks')
+
+     # 2. Determine Chunk Size
+     if chunk_size == 'auto':
+         # Note: pass d_masks.shape[1:] to exclude the 'N' dimension (we just want D, H, W)
+         chunk_size = get_optimal_chunk_size(d_masks.shape[1:], dtype=np.int32)
+
+         print(f"Auto-configured chunk_size: {chunk_size} masks")
+
+     # 3. Flatten Spatial Dimensions
+     d_masks = d_masks.reshape(d_masks.shape[0], -1)
+
+     # 4. Apply Chunking
+     d_masks = d_masks.rechunk({0: chunk_size})
+
+     # 5. Cast to int32
+     d_masks = d_masks.astype(np.int32)
+
+     # 6. Matrix Multiplication (Lazy)
+     intersection_graph = d_masks @ d_masks.T
+
+     print(f"Computing {d_masks.shape[0]}x{d_masks.shape[0]} Dice matrix...")
+     with ProgressBar():
+         intersection_matrix = intersection_graph.compute()
+
+     # 7. Compute Dice Score
+     volumes = intersection_matrix.diagonal()
+     volumes_sum_matrix = volumes[:, None] + volumes[None, :]
+
+     with np.errstate(divide='ignore', invalid='ignore'):
+         dice = (2 * intersection_matrix) / volumes_sum_matrix
+
+     return np.nan_to_num(dice, nan=1.0)
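
Note: a usage sketch; the function expects a store with a 'masks' array of shape (n_subjects, D, H, W), but the path and sizes here are illustrative:

    import numpy as np
    import zarr

    rng = np.random.default_rng(0)
    zarr.save_group("dataset.zarr", masks=(rng.random((5, 16, 16, 16)) > 0.5))

    D = dice_matrix_zarr("dataset.zarr")
    print(D.shape)  # (5, 5)
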
179
+
180
+
181
+ def hausdorff_matrix_in_memory(M, chunk_size = 1000): # (n_subjects, n_voxels)
182
+ # Chunk output to produce less and larger tasks, and less files
183
+ # Otherwise dask takes too long to schedule
184
+
185
+ # Convert from Boolean (True/False) to Integer (1/0)
186
+ # This ensures the dot product counts overlapping voxels.
187
+ M = M.astype(np.int32)
188
+
189
+ n = M.shape[0]
190
+ # Build a list of all index pairs in the sorted list that need computing
191
+ # Since the matrix is symmetric only half needs to be computed
192
+ pairs = [(i, j) for i in range(n) for j in range(i, n)]
193
+ # Split the list of index pairs up into chunks
194
+ chunks = [pairs[i:i+chunk_size] for i in range(0, len(pairs), chunk_size)]
195
+
196
+ # Compute dice scores for each chunk in parallel
197
+ logging.info("Hausdorff matrix - scheduling tasks..")
198
+ tasks = [
199
+ dask.delayed(_hausdorff_matrix_chunk)(M, chunk)
200
+ for chunk in chunks
201
+ ]
202
+ logging.info("Hausdorff matrix - computing tasks..")
203
+ with ProgressBar():
204
+ chunks = dask.compute(*tasks)
205
+
206
+ # Gather up all the chunks to build one matrix
207
+ logging.info(f"Hausdorff matrix - building matrix..")
208
+ haus_matrix = np.zeros((n, n), dtype=np.float32)
209
+ for chunk in chunks:
210
+ for (i, j), haus_ij in chunk.items():
211
+ haus_matrix[i, j] = haus_ij
212
+ haus_matrix[j, i] = haus_ij
213
+
214
+ return haus_matrix
215
+
216
+
217
+ def _hausdorff_matrix_chunk(M, pairs):
218
+ chunk = {}
219
+ for (i,j) in pairs:
220
+ # Load masks
221
+ mask_i = M[i, ...].astype(bool)
222
+ mask_j = M[j, ...].astype(bool)
223
+ # Compute metrics
224
+ haus_ij, _ = surface_distances(mask_i, mask_j)
225
+ # Add to results
226
+ chunk[(i, j)] = haus_ij
227
+ return chunk
+
+
+
+ def hausdorff_matrix_zarr(zarr_path: str):
+     # 1. Open metadata
+     z_root = zarr.open(zarr_path, mode='r')
+     n = z_root['masks'].shape[0]
+
+     logging.info(f"Hausdorff matrix: Scheduling {n} row tasks...")
+
+     # 2. Schedule one task per row
+     # Each task computes the distances for row i from [i to n]
+     tasks = [
+         dask.delayed(_compute_hausdorff_row)(zarr_path, i, n)
+         for i in range(n)
+     ]
+
+     # 3. Compute
+     with ProgressBar():
+         rows = dask.compute(*tasks)
+
+     # 4. Assemble
+     # 'rows' is a list of arrays of varying lengths
+     haus_matrix = np.zeros((n, n), dtype=np.float32)
+     for i, row_values in enumerate(rows):
+         # row_values contains distances for [i, i+1, ..., n-1]
+         haus_matrix[i, i:] = row_values
+         haus_matrix[i:, i] = row_values  # Mirror to lower triangle
+
+     return haus_matrix
+
+ def _compute_hausdorff_row(zarr_path, i, n):
+     """Computes all distances for a single row starting from the diagonal."""
+     z_masks = zarr.open(zarr_path, mode='r')['masks']
+
+     # Load mask_i once for the entire row
+     mask_i = z_masks[i].astype(bool)
+
+     # Pre-allocate result for the partial row
+     row_len = n - i
+     row_results = np.zeros(row_len, dtype=np.float32)
+
+     for idx, j in enumerate(range(i, n)):
+         if i == j:
+             row_results[idx] = 0.0
+             continue
+
+         mask_j = z_masks[j].astype(bool)
+         h_val, _ = surface_distances(mask_i, mask_j)
+         row_results[idx] = h_val
+
+     return row_results
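
Note: same store layout as the Dice example above; spherical masks are used so that every volume contains a surface for marching_cubes, and all sizes are illustrative:

    import numpy as np
    import zarr

    zz, yy, xx = np.mgrid[:32, :32, :32]
    masks = np.stack([(zz - 16)**2 + (yy - 16)**2 + (xx - 16)**2 < r**2 for r in (8, 10, 12)])
    zarr.save_group("dataset.zarr", masks=masks)

    H = hausdorff_matrix_zarr("dataset.zarr")
    print(H.shape)  # (3, 3), zeros on the diagonal
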