nystrom-ncut 0.0.1__py3-none-any.whl

@@ -0,0 +1,22 @@
+ from .ncut_pytorch import NCUT
+ from .propagation_utils import (
+     affinity_from_features,
+     propagate_eigenvectors,
+     propagate_knn,
+     quantile_normalize,
+ )
+ from .visualize_utils import (
+     eigenvector_to_rgb,
+     rgb_from_tsne_3d,
+     rgb_from_umap_sphere,
+     rgb_from_tsne_2d,
+     rgb_from_umap_3d,
+     rgb_from_umap_2d,
+     rgb_from_cosine_tsne_3d,
+     rotate_rgb_cube,
+     convert_to_lab_color,
+     propagate_rgb_color,
+     get_mask,
+ )
+ from .ncut_pytorch import nystrom_ncut, ncut
+ from .ncut_pytorch import kway_ncut, axis_align
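+
+ # Typical pipeline (illustrative sketch, based on the NCUT docstring in
+ # ncut_pytorch.py; names follow the public API re-exported above):
+ # >>> import torch
+ # >>> from ncut_pytorch import NCUT, eigenvector_to_rgb
+ # >>> features = torch.rand(10000, 100)
+ # >>> eigenvectors, eigenvalues = NCUT(num_eig=20).fit_transform(features)
+ # >>> X_3d, rgb = eigenvector_to_rgb(eigenvectors, method="tsne_3d")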
@@ -0,0 +1,561 @@
+ # %%
+ import logging
+ import math
+ from typing import Literal
+
+ import torch
+
+ from .nystrom import (
+     solve_eig,
+ )
+ from .propagation_utils import (
+     run_subgraph_sampling,
+     propagate_knn,
+     affinity_from_features,
+ )
+
+
+ class NCUT:
+     """Nystrom Normalized Cut for large-scale graphs."""
+
+     def __init__(
+         self,
+         num_eig: int = 100,
+         knn: int = 10,
+         affinity_focal_gamma: float = 1.0,
+         num_sample: int = 10000,
+         sample_method: Literal["farthest", "random"] = "farthest",
+         distance: Literal["cosine", "euclidean", "rbf"] = "cosine",
+         indirect_connection: bool = False,
+         indirect_pca_dim: int = 100,
+         device: str = None,
+         move_output_to_cpu: bool = False,
+         eig_solver: Literal["svd_lowrank", "lobpcg", "svd", "eigh"] = "svd_lowrank",
+         normalize_features: bool = None,
+         matmul_chunk_size: int = 8096,
+         make_orthogonal: bool = False,
+         verbose: bool = False,
+     ):
+         """
+         Args:
+             num_eig (int): number of top eigenvectors to return
+             knn (int): number of nearest neighbors used to propagate eigenvectors from the subgraph
+                 to the full graph; a smaller knn results in sharper eigenvectors
+             affinity_focal_gamma (float): affinity matrix temperature; a lower gamma shrinks the
+                 weights of weakly connected edges, resulting in sharper eigenvectors
+             num_sample (int): number of samples for the Nystrom-like approximation;
+                 reduce only if memory is insufficient, increase for a better approximation
+             sample_method (str): subgraph sampling method, ['farthest', 'random'];
+                 farthest point sampling is recommended for better Nystrom approximation accuracy
+             distance (str): distance metric for the affinity matrix, ['cosine', 'euclidean', 'rbf']
+             indirect_connection (bool): include indirect connections in the Nystrom-like approximation
+             indirect_pca_dim (int): PCA dimension used to reduce the node dimension when computing
+                 indirect connections
+             device (str): device to use for the eigen computation;
+                 moving to GPU speeds it up (~5x faster)
+             move_output_to_cpu (bool): move the output to CPU; set to True if you run into memory issues
+             eig_solver (str): eigendecomposition solver, ['svd_lowrank', 'lobpcg', 'svd', 'eigh']
+             normalize_features (bool): normalize input features before computing the affinity matrix;
+                 the default None means True for cosine distance, False for euclidean and rbf
+             matmul_chunk_size (int): chunk size for large-scale matrix multiplication
+             make_orthogonal (bool): make eigenvectors orthogonal post-hoc
+             verbose (bool): show progress bar
+
+         Examples:
+             >>> from ncut_pytorch import NCUT
+             >>> import torch
+             >>> features = torch.rand(10000, 100)
+             >>> ncut = NCUT(num_eig=20)
+             >>> ncut.fit(features)
+             >>> eigenvectors, eigenvalues = ncut.transform(features)
+             >>> print(eigenvectors.shape, eigenvalues.shape)
+             >>> # (10000, 20) (20,)
+
+             >>> from ncut_pytorch import eigenvector_to_rgb
+             >>> # use t-SNE or UMAP to convert eigenvectors to RGB
+             >>> X_3d, rgb = eigenvector_to_rgb(eigenvectors, method='tsne_3d')
+             >>> print(X_3d.shape, rgb.shape)
+             >>> # (10000, 3) (10000, 3)
+
+             >>> # transform new features
+             >>> new_features = torch.rand(500, 100)
+             >>> new_eigenvectors, _ = ncut.transform(new_features)
+             >>> print(new_eigenvectors.shape)
+             >>> # (500, 20)
+         """
+         self.num_eig = num_eig
+         self.num_sample = num_sample
+         self.knn = knn
+         self.sample_method = sample_method
+         self.distance = distance
+         self.affinity_focal_gamma = affinity_focal_gamma
+         self.indirect_connection = indirect_connection
+         self.indirect_pca_dim = indirect_pca_dim
+         self.device = device
+         self.move_output_to_cpu = move_output_to_cpu
+         self.eig_solver = eig_solver
+         self.normalize_features = normalize_features
+         if self.normalize_features is None:
+             if distance in ["cosine"]:
+                 self.normalize_features = True
+             if distance in ["euclidean", "rbf"]:
+                 self.normalize_features = False
+         self.matmul_chunk_size = matmul_chunk_size
+         self.make_orthogonal = make_orthogonal
+         self.verbose = verbose
+
+         self.subgraph_eigen_vector = None
+         self.eigen_value = None
+         self.subgraph_indices = None
+         self.subgraph_features = None
+
+     def fit(self,
+             features: torch.Tensor,
+             precomputed_sampled_indices: torch.Tensor = None
+             ):
+         """Fit Nystrom Normalized Cut on the input features.
+         Args:
+             features (torch.Tensor): input features, shape (n_samples, n_features)
+             precomputed_sampled_indices (torch.Tensor): precomputed sampled indices, shape (num_sample,);
+                 overrides sample_method if not None
+         Returns:
+             (NCUT): self
+         """
+         _n = features.shape[0]
+         if self.num_sample >= _n:
+             logging.info(
+                 f"NCUT num_sample is larger than the number of input samples; the Nystrom approximation is not needed, setting num_sample={_n} and knn=1"
+             )
+             self.num_sample = _n
+             self.knn = 1
+
+         # save the eigenvector solution on the sub-sampled graph; do not propagate to the full graph yet
+         self.subgraph_eigen_vector, self.eigen_value, self.subgraph_indices = nystrom_ncut(
+             features,
+             num_eig=self.num_eig,
+             num_sample=self.num_sample,
+             sample_method=self.sample_method,
+             precomputed_sampled_indices=precomputed_sampled_indices,
+             distance=self.distance,
+             affinity_focal_gamma=self.affinity_focal_gamma,
+             indirect_connection=self.indirect_connection,
+             indirect_pca_dim=self.indirect_pca_dim,
+             device=self.device,
+             eig_solver=self.eig_solver,
+             normalize_features=self.normalize_features,
+             matmul_chunk_size=self.matmul_chunk_size,
+             verbose=self.verbose,
+             no_propagation=True,
+             move_output_to_cpu=self.move_output_to_cpu,
+         )
+         self.subgraph_features = features[self.subgraph_indices]
+         return self
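+
+     # Usage sketch (illustrative, not part of the original file): reusing a fixed
+     # sample across fits via precomputed_sampled_indices. Any index tensor of
+     # shape (num_sample,) works; plain random indices are assumed here.
+     # >>> idx = torch.randperm(10000)[:1000]
+     # >>> ncut = NCUT(num_eig=20, num_sample=1000).fit(features, precomputed_sampled_indices=idx)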
+
+     def transform(self, features: torch.Tensor, knn: int = None):
+         """Transform new features using the fitted Nystrom Normalized Cut.
+         Args:
+             features (torch.Tensor): new features, shape (n_samples, n_features)
+             knn (int): number of nearest neighbors used to propagate eigenvectors from the subgraph
+                 to the full graph; defaults to self.knn if None
+         Returns:
+             (torch.Tensor): eigen_vectors, shape (n_samples, num_eig)
+             (torch.Tensor): eigen_values, sorted in descending order, shape (num_eig,)
+         """
+
+         knn = self.knn if knn is None else knn
+
+         # propagate eigenvectors from subgraph to full graph
+         eigen_vector = propagate_knn(
+             self.subgraph_eigen_vector,
+             features,
+             self.subgraph_features,
+             knn,
+             distance=self.distance,
+             chunk_size=self.matmul_chunk_size,
+             device=self.device,
+             use_tqdm=self.verbose,
+             move_output_to_cpu=self.move_output_to_cpu,
+         )
+         if self.make_orthogonal:
+             eigen_vector = gram_schmidt(eigen_vector)
+         return eigen_vector, self.eigen_value
+
+     def fit_transform(self,
+                       features: torch.Tensor,
+                       precomputed_sampled_indices: torch.Tensor = None
+                       ):
+         """
+         Args:
+             features (torch.Tensor): input features, shape (n_samples, n_features)
+             precomputed_sampled_indices (torch.Tensor): precomputed sampled indices, shape (num_sample,);
+                 overrides sample_method if not None
+
+         Returns:
+             (torch.Tensor): eigen_vectors, shape (n_samples, num_eig)
+             (torch.Tensor): eigen_values, sorted in descending order, shape (num_eig,)
+         """
+         return self.fit(features, precomputed_sampled_indices=precomputed_sampled_indices).transform(features)
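+
+
+ # Downstream sketch (illustrative, not part of the original file): turning the
+ # continuous eigenvectors into discrete cluster assignments with kway_ncut
+ # (defined below); rows of the result are one-hot, so argmax gives labels.
+ # >>> eigenvectors, eigenvalues = NCUT(num_eig=20).fit_transform(features)
+ # >>> onehot = kway_ncut(eigenvectors)
+ # >>> labels = onehot.argmax(dim=1)  # (n_samples,) cluster ids in [0, 20)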
+
+
+ def nystrom_ncut(
+     features: torch.Tensor,
+     num_eig: int = 100,
+     num_sample: int = 10000,
+     knn: int = 10,
+     sample_method: Literal["farthest", "random"] = "farthest",
+     precomputed_sampled_indices: torch.Tensor = None,
+     distance: Literal["cosine", "euclidean", "rbf"] = "cosine",
+     affinity_focal_gamma: float = 1.0,
+     indirect_connection: bool = True,
+     indirect_pca_dim: int = 100,
+     device: str = None,
+     eig_solver: Literal["svd_lowrank", "lobpcg", "svd", "eigh"] = "svd_lowrank",
+     normalize_features: bool = None,
+     matmul_chunk_size: int = 8096,
+     make_orthogonal: bool = True,
+     verbose: bool = False,
+     no_propagation: bool = False,
+     move_output_to_cpu: bool = False,
+ ):
+     """PyTorch implementation of the faster Nystrom Normalized Cut.
+     Args:
+         features (torch.Tensor): feature matrix, shape (n_samples, n_features)
+         num_eig (int): default 100, number of top eigenvectors to return
+         num_sample (int): default 10000, number of samples for the Nystrom-like approximation
+         knn (int): default 10, number of nearest neighbors used to propagate eigenvectors from the
+             subgraph to the full graph; a smaller knn results in sharper eigenvectors
+         sample_method (str): sample method, 'farthest' (default) or 'random';
+             'farthest' is recommended for a better approximation
+         precomputed_sampled_indices (torch.Tensor): precomputed sampled indices, shape (num_sample,);
+             overrides sample_method if not None
+         distance (str): distance metric, 'cosine' (default), 'euclidean', or 'rbf'
+         affinity_focal_gamma (float): affinity matrix parameter; a lower gamma shrinks weak edge
+             weights, resulting in sharper eigenvectors, default 1.0
+         indirect_connection (bool): include indirect connections in the subgraph, default True
+         indirect_pca_dim (int): default 100, PCA dimension used to reduce the node dimension;
+             only applied to the not-sampled nodes, not to the sampled nodes
+         device (str): device to use for computation; if None, the device is not changed.
+             A good practice is to pass features on CPU, since they are usually large,
+             and move the subgraph affinity to GPU to speed up the eigenvector computation
+         eig_solver (str): eigendecomposition solver, 'svd_lowrank' (default), 'lobpcg', 'svd', 'eigh';
+             'svd_lowrank' is recommended for large-scale graphs, it is the fastest;
+             they correspond to torch.svd_lowrank, torch.lobpcg, torch.svd, torch.linalg.eigh
+         normalize_features (bool): normalize input features before computing the affinity matrix;
+             the default None means True for cosine distance, False for euclidean and rbf
+         matmul_chunk_size (int): chunk size for matrix multiplication;
+             large matrix multiplications are chunked to reduce memory usage, and a smaller
+             chunk size reduces memory at the cost of slower computation, default 8096
+         make_orthogonal (bool): make eigenvectors orthogonal after propagation, default True
+         verbose (bool): show a progress bar when propagating eigenvectors from the subgraph to the full graph
+         no_propagation (bool): if True, skip the eigenvector propagation step and only return the subgraph eigenvectors
+         move_output_to_cpu (bool): move the output to CPU; set to True if you run into memory issues
+     Returns:
+         (torch.Tensor): eigenvectors, shape (n_samples, num_eig)
+         (torch.Tensor): eigenvalues, sorted in descending order, shape (num_eig,)
+         (torch.Tensor): sampled_indices used by the Nystrom-like approximation subgraph, shape (num_sample,)
+     """
+
+     # check that the number of nodes is large enough for num_eig
+     if eig_solver in ["svd_lowrank", "lobpcg"]:
+         assert features.shape[0] > (
+             num_eig * 2
+         ), "number of nodes should be greater than 2*num_eig"
+     if eig_solver in ["svd", "eigh"]:
+         assert (
+             features.shape[0] > num_eig
+         ), "number of nodes should be greater than num_eig"
+
+     assert distance in ["cosine", "euclidean", "rbf"], "distance should be 'cosine', 'euclidean', or 'rbf'"
+
+     if normalize_features:
+         # features need to be normalized for affinity matrix computation (cosine distance)
+         features = torch.nn.functional.normalize(features, dim=-1)
+
+     if precomputed_sampled_indices is not None:
+         sampled_indices = precomputed_sampled_indices
+     else:
+         sampled_indices = run_subgraph_sampling(
+             features,
+             num_sample=num_sample,
+             sample_method=sample_method,
+         )
+
+     sampled_features = features[sampled_indices]
+     # move the subgraph to GPU to speed up the eigen computation
+     original_device = sampled_features.device
+     device = original_device if device is None else device
+     sampled_features = sampled_features.to(device)
+
+     # compute the affinity matrix on the subgraph
+     A = affinity_from_features(
+         sampled_features,
+         affinity_focal_gamma=affinity_focal_gamma,
+         distance=distance,
+     )
+
+     # check if all nodes are sampled; if so, no Nystrom approximation is needed
+     not_sampled = torch.full((features.shape[0],), True, device=features.device)
+     not_sampled[sampled_indices] = False
+     _n_not_sampled = not_sampled.sum()
+
+     if _n_not_sampled == 0:
+         # all nodes are sampled, so the Nystrom approximation is not needed
+         eigen_vector, eigen_value = ncut(A, num_eig, eig_solver=eig_solver)
+         return eigen_vector, eigen_value, sampled_indices
+
+     # 1) PCA to reduce the node dimension of the not-sampled nodes
+     # 2) compute indirect connections through the PCA-compressed nodes
+     if _n_not_sampled > 0 and indirect_connection:
+         indirect_pca_dim = min(indirect_pca_dim, *features.shape)
+         U, S, V = torch.pca_lowrank(features[not_sampled].T, q=indirect_pca_dim)
+         S = S / math.sqrt(_n_not_sampled)
+         feature_B_T = U @ torch.diag(S)
+         feature_B = feature_B_T.T
+         feature_B = feature_B.to(device)
+
+         # cross-affinity between the sampled nodes and the compressed not-sampled nodes
+         B = affinity_from_features(
+             sampled_features,
+             feature_B,
+             affinity_focal_gamma=affinity_focal_gamma,
+             distance=distance,
+             fill_diagonal=False,
+         )
+         # P is the 1-hop random-walk matrix through the not-sampled nodes
+         B_row = B / B.sum(dim=1, keepdim=True)
+         B_col = B / B.sum(dim=0, keepdim=True)
+         P = B_row @ B_col.T
+         P = (P + P.T) / 2
+         # fill the diagonal with 0
+         P[torch.arange(P.shape[0]), torch.arange(P.shape[0])] = 0
+         A = A + P
+
+     # compute normalized cut on the subgraph
+     eigen_vector, eigen_value = ncut(A, num_eig, eig_solver=eig_solver)
+     eigen_vector = eigen_vector.to(dtype=features.dtype, device=original_device)
+     eigen_value = eigen_value.to(dtype=features.dtype, device=original_device)
+
+     if no_propagation:
+         return eigen_vector, eigen_value, sampled_indices
+
+     # propagate eigenvectors from the subgraph to the full graph
+     eigen_vector = propagate_knn(
+         eigen_vector,
+         features,
+         sampled_features,
+         knn,
+         distance=distance,
+         chunk_size=matmul_chunk_size,
+         device=device,
+         use_tqdm=verbose,
+         move_output_to_cpu=move_output_to_cpu,
+     )
+
+     # post-hoc orthogonalization
+     if make_orthogonal:
+         eigen_vector = gram_schmidt(eigen_vector)
+
+     return eigen_vector, eigen_value, sampled_indices
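+
+ # Usage sketch (illustrative, not part of the original file): the functional API,
+ # following the device tip in the docstring above; keep the large feature tensor
+ # on CPU and let the subgraph eigendecomposition run on GPU.
+ # >>> features = torch.rand(100000, 768)  # stays on CPU
+ # >>> eigvec, eigval, idx = nystrom_ncut(features, num_eig=50, device="cuda")
+ # >>> eigvec.shape, eigval.shape, idx.shape
+ # (torch.Size([100000, 50]), torch.Size([50]), torch.Size([10000]))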
+
+
+ def normalized_affinity_transform(D: torch.Tensor, affinity_focal_gamma: float):
+     """Compute a Laplacian-normalized affinity matrix from a pairwise distance matrix.
+
+     Args:
+         D (torch.Tensor): pairwise distance matrix, shape (n_samples, n_samples)
+         affinity_focal_gamma (float): affinity matrix parameter; a lower gamma shrinks
+             the edge weights of weak connections
+
+     Returns:
+         (torch.Tensor): normalized affinity matrix, shape (n_samples, n_samples)
+     """
+     # make sure D is symmetric
+     D = (D + D.T) / 2
+     A = torch.exp(-D / affinity_focal_gamma)
+
+     # symmetric normalization: A = D^(-1/2) A D^(-1/2), where D is the degree (row-sum) of A
+     D = A.sum(dim=-1).detach().clone()
+     A /= torch.sqrt(D)[:, None]
+     A /= torch.sqrt(D)[None, :]
+     return A
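+
+ # Sketch (illustrative, not part of the original file): feeding a pairwise
+ # euclidean distance matrix through the transform above.
+ # >>> x = torch.rand(100, 8)
+ # >>> A = normalized_affinity_transform(torch.cdist(x, x), affinity_focal_gamma=1.0)
+ # >>> A.shape
+ # torch.Size([100, 100])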
+
+
+ def ncut(
+     A: torch.Tensor,
+     num_eig: int = 100,
+     eig_solver: Literal["svd_lowrank", "lobpcg", "svd", "eigh"] = "svd_lowrank",
+ ):
+     """PyTorch implementation of Normalized Cut without the Nystrom-like approximation.
+
+     Args:
+         A (torch.Tensor): affinity matrix, shape (n_samples, n_samples)
+         num_eig (int): number of eigenvectors to return
+         eig_solver (str): eigendecomposition solver, ['svd_lowrank', 'lobpcg', 'svd', 'eigh']
+
+     Returns:
+         (torch.Tensor): eigenvectors corresponding to the eigenvalues, shape (n_samples, num_eig)
+         (torch.Tensor): eigenvalues of the eigenvectors, sorted in descending order
+     """
+     # make sure A is symmetric
+     A = (A + A.T) / 2
+
+     # symmetric normalization: A = D^(-1/2) A D^(-1/2), where D is the degree (row-sum) of A
+     D = A.sum(dim=-1).detach().clone()
+     A /= torch.sqrt(D)[:, None]
+     A /= torch.sqrt(D)[None, :]
+
+     # compute eigenvectors
+     eigen_vector, eigen_value = solve_eig(A, num_eig, eig_solver)
+
+     if eigen_value.min() < 0:
+         logging.warning(
+             "negative eigenvalues detected, please make sure the affinity matrix is positive semi-definite"
+         )
+
+     return eigen_vector, eigen_value
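+
+ # Sketch (illustrative, not part of the original file): exact NCUT on a small
+ # dense affinity matrix, using the 'eigh' solver since the graph is tiny.
+ # >>> x = torch.rand(50, 8)
+ # >>> A = affinity_from_features(x, affinity_focal_gamma=1.0, distance="cosine")
+ # >>> eigvec, eigval = ncut(A, num_eig=5, eig_solver="eigh")
+ # >>> eigvec.shape, eigval.shape
+ # (torch.Size([50, 5]), torch.Size([5]))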
+
+
+ def gram_schmidt(matrix):
+     """Orthogonalize a matrix column-wise using the Gram-Schmidt process.
+
+     Args:
+         matrix (torch.Tensor): matrix to be orthogonalized, shape (m, n);
+             the second dimension (columns) is orthogonalized
+     Returns:
+         torch.Tensor: orthogonalized matrix, shape (m, n)
+     """
+
+     # get the number of rows (m) and columns (n) of the input matrix
+     m, n = matrix.shape
+
+     # create an empty matrix to store the orthogonalized columns
+     orthogonal_matrix = torch.zeros((m, n), dtype=matrix.dtype, device=matrix.device)
+
+     for i in range(n):
+         # start with the i-th column of the input matrix
+         vec = matrix[:, i]
+
+         for j in range(i):
+             # subtract the projection of vec onto the j-th orthogonal column
+             proj = torch.dot(orthogonal_matrix[:, j], matrix[:, i]) / torch.dot(
+                 orthogonal_matrix[:, j], orthogonal_matrix[:, j]
+             )
+             vec = vec - proj * orthogonal_matrix[:, j]
+
+         # store the normalized orthogonalized vector
+         orthogonal_matrix[:, i] = vec / torch.norm(vec)
+
+     return orthogonal_matrix
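+
+ # Sketch (illustrative, not part of the original file): checking that the
+ # Gram-Schmidt output has orthonormal columns, i.e. Q^T Q is close to identity.
+ # >>> M = torch.rand(100, 10)
+ # >>> Q = gram_schmidt(M)
+ # >>> torch.allclose(Q.T @ Q, torch.eye(10), atol=1e-4)
+ # True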
+
+
+ def correct_rotation(eigen_vector):
+     # correct the arbitrary sign flip of eigenvectors:
+     # flip each eigenvector so that its entries sum to a positive value
+     rand_w = torch.ones(
+         eigen_vector.shape[0], device=eigen_vector.device, dtype=eigen_vector.dtype
+     )
+     s = rand_w[None, :] @ eigen_vector
+     s = s.sign()
+     return eigen_vector * s
+
+
+ # Multiclass Spectral Clustering, S. X. Yu, J. Shi, 2003
+ def _discretisation_eigenvector(eigen_vector):
+     # discretize the rotated eigenvectors: one-hot encode the argmax of each row
+     n, k = eigen_vector.shape
+
+     # find the index of the maximum entry along each row
+     _, J = torch.max(eigen_vector, dim=1)
+     Y = torch.zeros(n, k, device=eigen_vector.device).scatter_(1, J.unsqueeze(1), 1)
+
+     return Y
+
+
+ def kway_ncut(eigen_vectors: torch.Tensor, max_iter=300, return_rotation=False):
+     """Multiclass Spectral Clustering, S. X. Yu, J. Shi, 2003.
+
+     Args:
+         eigen_vectors (torch.Tensor): continuous eigenvectors from NCUT, shape (n, k)
+         max_iter (int, optional): maximum number of iterations
+         return_rotation (bool, optional): if True, also return the rotation matrix R, shape (k, k)
+
+     Returns:
+         torch.Tensor: discretized eigenvectors, shape (n, k); each row is a one-hot vector.
+     """
+     # normalize eigenvectors to unit row norm
+     n, k = eigen_vectors.shape
+     vm = torch.sqrt(torch.sum(eigen_vectors ** 2, dim=1))
+     eigen_vectors = eigen_vectors / vm.unsqueeze(1)
+
+     # initialize R with the first column taken from a random row of the eigenvectors
+     R = torch.zeros(k, k, device=eigen_vectors.device)
+     R[:, 0] = eigen_vectors[torch.randint(0, n, (1,))].squeeze()
+
+     # loop to populate R with k orthogonal directions
+     c = torch.zeros(n, device=eigen_vectors.device)
+     for j in range(1, k):
+         c += torch.abs(eigen_vectors @ R[:, j - 1])
+         _, i = torch.min(c, dim=0)
+         R[:, j] = eigen_vectors[i]
+
+     # iterative optimization loop
+     last_objective_value = 0
+     exit_loop = False
+     nb_iterations_discretisation = 0
+
+     while not exit_loop:
+         nb_iterations_discretisation += 1
+
+         # discretize the projected eigenvectors
+         eigenvectors_discrete = _discretisation_eigenvector(eigen_vectors @ R)
+
+         # SVD decomposition
+         U, S, Vh = torch.linalg.svd(eigenvectors_discrete.T @ eigen_vectors, full_matrices=False)
+         V = Vh.T
+
+         # compute the NCut value
+         ncut_value = 2 * (n - torch.sum(S))
+
+         # check for convergence
+         if torch.abs(ncut_value - last_objective_value) < torch.finfo(
+                 torch.float32).eps or nb_iterations_discretisation > max_iter:
+             exit_loop = True
+         else:
+             last_objective_value = ncut_value
+             R = V @ U.T
+
+     if return_rotation:
+         return eigenvectors_discrete, R
+
+     return eigenvectors_discrete
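+
+ # Usage sketch (illustrative, not part of the original file): cluster ids from
+ # the one-hot output; k clusters come from the first k NCUT eigenvectors.
+ # >>> onehot = kway_ncut(eigenvectors[:, :10])
+ # >>> labels = onehot.argmax(dim=1)  # (n,) integer cluster assignments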
+
+
+ def axis_align(eigen_vectors, max_iter=300):
+     # alias that also returns the rotation matrix: (discretized eigenvectors, R)
+     return kway_ncut(eigen_vectors, max_iter=max_iter, return_rotation=True)
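+
+ # Sketch (illustrative, not part of the original file): axis_align returns both
+ # the one-hot assignments and the (k, k) rotation found by kway_ncut.
+ # >>> onehot, R = axis_align(eigenvectors[:, :10])
+ # >>> aligned = eigenvectors[:, :10] @ R  # rotated continuous eigenvectors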
+
+
+ ## for backward compatibility ##
+
+ try:
+
+     from .propagation_utils import (
+         propagate_nearest,
+         propagate_eigenvectors,
+         quantile_normalize,
+         quantile_min_max,
+         farthest_point_sampling,
+     )
+     from .visualize_utils import (
+         eigenvector_to_rgb,
+         rgb_from_tsne_3d,
+         rgb_from_umap_sphere,
+         rgb_from_tsne_2d,
+         rgb_from_umap_3d,
+         rgb_from_umap_2d,
+         rotate_rgb_cube,
+         convert_to_lab_color,
+         _transform_heatmap,
+         _clean_mask,
+         get_mask,
+     )
+
+ except ImportError:
+     print("some of the visualization and nystrom_utils imports are not available")