nystrom-ncut 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

@@ -1,46 +1,112 @@
- # %%
  import logging
- import math
- from typing import Literal
+ from typing import Literal, Tuple

  import torch
+ import torch.nn.functional as Fn

  from .nystrom import (
-     solve_eig
+     EigSolverOptions,
+     OnlineKernel,
+     OnlineNystrom,
+     solve_eig,
  )
  from .propagation_utils import (
-     run_subgraph_sampling,
-     propagate_knn,
      affinity_from_features,
+     run_subgraph_sampling,
  )


- class NCUT:
+ DistanceOptions = Literal["cosine", "euclidean", "rbf"]
+
+
+ class LaplacianKernel(OnlineKernel):
+     def __init__(
+         self,
+         affinity_focal_gamma: float,
+         distance: DistanceOptions,
+         eig_solver: EigSolverOptions,
+     ):
+         self.affinity_focal_gamma = affinity_focal_gamma
+         self.distance: DistanceOptions = distance
+         self.eig_solver: EigSolverOptions = eig_solver
+
+         # Anchor matrices
+         self.anchor_features: torch.Tensor = None    # [n x d]
+         self.A: torch.Tensor = None                  # [n x n]
+         self.Ainv: torch.Tensor = None               # [n x n]
+
+         # Updated matrices
+         self.a_r: torch.Tensor = None                # [n]
+         self.b_r: torch.Tensor = None                # [n]
+
+     def fit(self, features: torch.Tensor) -> None:
+         self.anchor_features = features              # [n x d]
+         self.A = affinity_from_features(
+             self.anchor_features,                    # [n x d]
+             affinity_focal_gamma=self.affinity_focal_gamma,
+             distance=self.distance,
+         )                                            # [n x n]
+         U, L = solve_eig(
+             self.A,
+             num_eig=features.shape[-1] + 1,
+             eig_solver=self.eig_solver,
+         )                                            # [n x (d + 1)], [d + 1]
+         self.Ainv = U @ torch.diag(1 / L) @ U.mT     # [n x n]
+         self.a_r = torch.sum(self.A, dim=-1)         # [n]
+         self.b_r = torch.zeros_like(self.a_r)        # [n]
+
+     def update(self, features: torch.Tensor) -> torch.Tensor:
+         B = affinity_from_features(
+             self.anchor_features,                    # [n x d]
+             features,                                # [m x d]
+             affinity_focal_gamma=self.affinity_focal_gamma,
+             distance=self.distance,
+         )                                            # [n x m]
+         b_r = torch.sum(B, dim=-1)                   # [n]
+         b_c = torch.sum(B, dim=-2)                   # [m]
+         self.b_r = self.b_r + b_r                    # [n]
+
+         rowscale = self.a_r + self.b_r               # [n]
+         colscale = b_c + B.mT @ self.Ainv @ self.b_r             # [m]
+         scale = (rowscale[:, None] * colscale) ** -0.5           # [n x m]
+         return (B * scale).mT                        # [m x n]
+
+     def transform(self, features: torch.Tensor = None) -> torch.Tensor:
+         rowscale = self.a_r + self.b_r               # [n]
+         if features is None:
+             B = self.A                               # [n x n]
+             colscale = rowscale                      # [n]
+         else:
+             B = affinity_from_features(
+                 self.anchor_features,                # [n x d]
+                 features,                            # [m x d]
+                 affinity_focal_gamma=self.affinity_focal_gamma,
+                 distance=self.distance,
+             )                                        # [n x m]
+             b_c = torch.sum(B, dim=-2)               # [m]
+             colscale = b_c + B.mT @ self.Ainv @ self.b_r         # [m]
+         scale = (rowscale[:, None] * colscale) ** -0.5           # [n x m]
+         return (B * scale).mT                        # [m x n]
+
+
+ class NCUT(OnlineNystrom):
      """Nystrom Normalized Cut for large scale graph."""

      def __init__(
          self,
-         num_eig: int = 100,
-         knn: int = 10,
+         n_components: int = 100,
          affinity_focal_gamma: float = 1.0,
          num_sample: int = 10000,
          sample_method: Literal["farthest", "random"] = "farthest",
-         distance: Literal["cosine", "euclidean", "rbf"] = "cosine",
-         indirect_connection: bool = False,
-         indirect_pca_dim: int = 100,
-         device: str = None,
-         move_output_to_cpu: bool = False,
-         eig_solver: Literal["svd_lowrank", "lobpcg", "svd", "eigh"] = "svd_lowrank",
+         distance: DistanceOptions = "cosine",
+         eig_solver: EigSolverOptions = "svd_lowrank",
          normalize_features: bool = None,
-         matmul_chunk_size: int = 8096,
-         make_orthogonal: bool = False,
-         verbose: bool = False,
+         move_output_to_cpu: bool = False,
+         chunk_size: int = 8192,
      ):
          """
          Args:
-             num_eig (int): number of top eigenvectors to return
-             knn (int): number of KNN for propagating eigenvectors from subgraph to full graph,
-                 smaller knn result in more sharp eigenvectors.
+             n_components (int): number of top eigenvectors to return
              affinity_focal_gamma (float): affinity matrix temperature, lower t reduce the not-so-connected edge weights,
                  smaller t result in more sharp eigenvectors.
              num_sample (int): number of samples for Nystrom-like approximation,
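Reviewer note on the new `LaplacianKernel`: the `colscale` term in `update` and `transform` is a Nyström estimate of each out-of-sample node's degree. Writing the full affinity in blocks, with A over the n anchors, B between anchors and the m new points, and C among the new points, a new point's true degree is b_c + C·1; substituting the Nyström completion C ≈ Bᵀ A⁻¹ B yields exactly `b_c + B.mT @ self.Ainv @ self.b_r`. A minimal self-contained check of that identity on toy matrices (plain torch, single batch; this does not use the package's `affinity_from_features`, and `Ainv` in the class is a truncated pseudo-inverse rather than the exact inverse used below):

    import torch

    torch.manual_seed(0)
    n, m = 5, 3
    X = torch.randn(n + m, 4)
    W = torch.exp(-torch.cdist(X, X))        # toy affinity: n anchors + m new points
    A, B, C = W[:n, :n], W[:n, n:], W[n:, n:]

    b_r = B.sum(dim=-1)                      # [n], what self.b_r holds after one update
    b_c = B.sum(dim=-2)                      # [m]
    d_exact = b_c + C.sum(dim=-1)            # true degrees of the m new points
    d_nystrom = b_c + B.mT @ torch.linalg.inv(A) @ b_r   # uses C ~= B^T A^-1 B
    print(d_exact)
    print(d_nystrom)                         # close; the gap is the low-rank error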
@@ -48,140 +114,101 @@ class NCUT:
              sample_method (str): subgraph sampling, ['farthest', 'random'].
                  farthest point sampling is recommended for better Nystrom-approximation accuracy
              distance (str): distance metric for affinity matrix, ['cosine', 'euclidean', 'rbf'].
-             indirect_connection (bool): include indirect connection in the Nystrom-like approximation
-             indirect_pca_dim (int): when compute indirect connection, PCA to reduce the node dimension,
-             device (str): device to use for eigen computation,
-                 move to GPU to speeds up a bit (~5x faster)
-             move_output_to_cpu (bool): move output to CPU, set to True if you have memory issue
              eig_solver (str): eigen decompose solver, ['svd_lowrank', 'lobpcg', 'svd', 'eigh'].
              normalize_features (bool): normalize input features before computing affinity matrix,
                  default 'None' is True for cosine distance, False for euclidean distance and rbf
-             matmul_chunk_size (int): chunk size for large-scale matrix multiplication
-             make_orthogonal (bool): make eigenvectors orthogonal post-hoc
-             verbose (bool): progress bar
-
-         Examples:
-             >>> from ncut_pytorch import NCUT
-             >>> import torch
-             >>> features = torch.rand(10000, 100)
-             >>> ncut = NCUT(num_eig=20)
-             >>> ncut.fit(features)
-             >>> eigenvectors, eigenvalues = ncut.transform(features)
-             >>> print(eigenvectors.shape, eigenvalues.shape)
-             >>> # (10000, 20) (20,)
-
-             >>> from ncut_pytorch import eigenvector_to_rgb
-             >>> # use t-SNE or UMAP to convert eigenvectors to RGB
-             >>> X_3d, rgb = eigenvector_to_rgb(eigenvectors, method='tsne_3d')
-             >>> print(X_3d.shape, rgb.shape)
-             >>> # (10000, 3) (10000, 3)
-
-             >>> # transform new features
-             >>> new_features = torch.rand(500, 100)
-             >>> new_eigenvectors, _ = ncut.transform(new_features)
-             >>> print(new_eigenvectors.shape)
-             >>> # (500, 20)
+             move_output_to_cpu (bool): move output to CPU, set to True if you have memory issue
+             chunk_size (int): chunk size for large-scale matrix multiplication
          """
-         self.num_eig = num_eig
+         OnlineNystrom.__init__(
+             self,
+             n_components=n_components,
+             kernel=LaplacianKernel(affinity_focal_gamma, distance, eig_solver),
+             eig_solver=eig_solver,
+             chunk_size=chunk_size,
+         )
          self.num_sample = num_sample
-         self.knn = knn
          self.sample_method = sample_method
          self.distance = distance
-         self.affinity_focal_gamma = affinity_focal_gamma
-         self.indirect_connection = indirect_connection
-         self.indirect_pca_dim = indirect_pca_dim
-         self.device = device
-         self.move_output_to_cpu = move_output_to_cpu
-         self.eig_solver = eig_solver
          self.normalize_features = normalize_features
          if self.normalize_features is None:
              if distance in ["cosine"]:
                  self.normalize_features = True
              if distance in ["euclidean", "rbf"]:
                  self.normalize_features = False
-         self.matmul_chunk_size = matmul_chunk_size
-         self.make_orthogonal = make_orthogonal
-         self.verbose = verbose
-
-         self.subgraph_eigen_vector = None
-         self.eigen_value = None
-         self.subgraph_indices = None
-         self.subgraph_features = None
-
-     def fit(self,
-             features: torch.Tensor,
-             precomputed_sampled_indices: torch.Tensor = None
-             ):
-         """Fit Nystrom Normalized Cut on the input features.
-         Args:
-             features (torch.Tensor): input features, shape (n_samples, n_features)
-             precomputed_sampled_indices (torch.Tensor): precomputed sampled indices, shape (num_sample,)
-                 override the sample_method, if not None
-         Returns:
-             (NCUT): self
-         """
+
+         self.move_output_to_cpu = move_output_to_cpu
+         self.chunk_size = chunk_size
+
+     def _fit_helper(
+         self,
+         features: torch.Tensor,
+         precomputed_sampled_indices: torch.Tensor,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
          _n = features.shape[0]
          if self.num_sample >= _n:
              logging.info(
-                 f"NCUT nystrom num_sample is larger than number of input samples, nyström approximation is not needed, setting num_sample={_n} and knn=1"
+                 f"NCUT nystrom num_sample is larger than number of input samples, nyström approximation is not needed, setting num_sample={_n}"
              )
              self.num_sample = _n
-             self.knn = 1
-
-         # save the eigenvectors solution on the sub-sampled graph, do not propagate to full graph yet
-         self.subgraph_eigen_vector, self.eigen_value, self.subgraph_indices = nystrom_ncut(
-             features,
-             num_eig=self.num_eig,
-             num_sample=self.num_sample,
-             sample_method=self.sample_method,
-             precomputed_sampled_indices=precomputed_sampled_indices,
-             distance=self.distance,
-             affinity_focal_gamma=self.affinity_focal_gamma,
-             indirect_connection=self.indirect_connection,
-             indirect_pca_dim=self.indirect_pca_dim,
-             device=self.device,
-             eig_solver=self.eig_solver,
-             normalize_features=self.normalize_features,
-             matmul_chunk_size=self.matmul_chunk_size,
-             verbose=self.verbose,
-             no_propagation=True,
-             move_output_to_cpu=self.move_output_to_cpu,
-         )
-         self.subgraph_features = features[self.subgraph_indices]
-         return self

-     def transform(self, features: torch.Tensor, knn: int = None):
-         """Transform new features using the fitted Nystrom Normalized Cut.
+         # check if features dimension greater than num_eig
+         if self.eig_solver in ["svd_lowrank", "lobpcg"]:
+             assert (
+                 _n >= self.n_components * 2
+             ), "number of nodes should be greater than 2*num_eig"
+         elif self.eig_solver in ["svd", "eigh"]:
+             assert (
+                 _n >= self.n_components
+             ), "number of nodes should be greater than num_eig"
+
+         assert self.distance in ["cosine", "euclidean", "rbf"], "distance should be 'cosine', 'euclidean', 'rbf'"
+
+         if self.normalize_features:
+             # features need to be normalized for affinity matrix computation (cosine distance)
+             features = torch.nn.functional.normalize(features, dim=-1)
+
+         if precomputed_sampled_indices is not None:
+             sampled_indices = precomputed_sampled_indices
+         else:
+             sampled_indices = run_subgraph_sampling(
+                 features,
+                 num_sample=self.num_sample,
+                 sample_method=self.sample_method,
+             )
+         sampled_features = features[sampled_indices]
+         OnlineNystrom.fit(self, sampled_features)
+
+         _n_not_sampled = _n - len(sampled_features)
+         if _n_not_sampled > 0:
+             unsampled_indices = torch.full((_n,), True, device=features.device).scatter_(0, sampled_indices, False)
+             unsampled_features = features[unsampled_indices]
+             V_unsampled, _ = OnlineNystrom.update(self, unsampled_features)
+         else:
+             unsampled_indices = V_unsampled = None
+         return unsampled_indices, V_unsampled
+
+     def fit(
+         self,
+         features: torch.Tensor,
+         precomputed_sampled_indices: torch.Tensor = None,
+     ):
+         """Fit Nystrom Normalized Cut on the input features.
          Args:
-             features (torch.Tensor): new features, shape (n_samples, n_features)
-             knn (int): number of KNN for propagating eigenvectors from subgraph to full graph,
+             features (torch.Tensor): input features, shape (n_samples, n_features)
+             precomputed_sampled_indices (torch.Tensor): precomputed sampled indices, shape (num_sample,)
+                 override the sample_method, if not None
          Returns:
-             (torch.Tensor): eigen_vectors, shape (n_samples, num_eig)
-             (torch.Tensor): eigen_values, sorted in descending order, shape (num_eig,)
+             (NCUT): self
          """
+         NCUT._fit_helper(self, features, precomputed_sampled_indices)
+         return self

-         knn = self.knn if knn is None else knn
-
-         # propagate eigenvectors from subgraph to full graph
-         eigen_vector = propagate_knn(
-             self.subgraph_eigen_vector,
-             features,
-             self.subgraph_features,
-             knn,
-             distance=self.distance,
-             chunk_size=self.matmul_chunk_size,
-             device=self.device,
-             use_tqdm=self.verbose,
-             move_output_to_cpu=self.move_output_to_cpu,
-         )
-         if self.make_orthogonal:
-             eigen_vector = gram_schmidt(eigen_vector)
-         return eigen_vector, self.eigen_value
-
-     def fit_transform(self,
-                       features: torch.Tensor,
-                       precomputed_sampled_indices: torch.Tensor = None
-                       ):
+     def fit_transform(
+         self,
+         features: torch.Tensor,
+         precomputed_sampled_indices: torch.Tensor = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
          """
          Args:
              features (torch.Tensor): input features, shape (n_samples, n_features)
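The 0.0.1 doctest in the class docstring above was dropped rather than ported. For orientation, here is an equivalent usage sketch against the 0.0.3 API shown in this diff (`n_components` replaces `num_eig`, and the former `knn` propagation step is gone; the top-level import path is assumed from the wheel name, while the old docs used `ncut_pytorch`):

    import torch
    from nystrom_ncut import NCUT            # import path assumed

    features = torch.rand(10000, 100)
    ncut = NCUT(n_components=20)             # was NCUT(num_eig=20) in 0.0.1
    eigenvectors, eigenvalues = ncut.fit_transform(features)
    print(eigenvectors.shape, eigenvalues.shape)   # (10000, 20) (20,)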
@@ -192,285 +219,19 @@ class NCUT:
              (torch.Tensor): eigen_vectors, shape (n_samples, num_eig)
              (torch.Tensor): eigen_values, sorted in descending order, shape (num_eig,)
          """
-         return self.fit(features, precomputed_sampled_indices=precomputed_sampled_indices).transform(features)
-
-
- def nystrom_ncut(
-     features: torch.Tensor,
-     num_eig: int = 100,
-     num_sample: int = 10000,
-     knn: int = 10,
-     sample_method: Literal["farthest", "random"] = "farthest",
-     precomputed_sampled_indices: torch.Tensor = None,
-     distance: Literal["cosine", "euclidean", "rbf"] = "cosine",
-     affinity_focal_gamma: float = 1.0,
-     indirect_connection: bool = True,
-     indirect_pca_dim: int = 100,
-     device: str = None,
-     eig_solver: Literal["svd_lowrank", "lobpcg", "svd", "eigh"] = "svd_lowrank",
-     normalize_features: bool = None,
-     matmul_chunk_size: int = 8096,
-     make_orthogonal: bool = True,
-     verbose: bool = False,
-     no_propagation: bool = False,
-     move_output_to_cpu: bool = False,
- ):
-     """PyTorch implementation of Faster Nystrom Normalized cut.
-     Args:
-         features (torch.Tensor): feature matrix, shape (n_samples, n_features)
-         num_eig (int): default 100, number of top eigenvectors to return
-         num_sample (int): default 10000, number of samples for Nystrom-like approximation
-         knn (int): default 10, number of KNN for propagating eigenvectors from subgraph to full graph,
-             smaller knn will result in more sharp eigenvectors,
-         sample_method (str): sample method, 'farthest' (default) or 'random'
-             'farthest' is recommended for better approximation
-         precomputed_sampled_indices (torch.Tensor): precomputed sampled indices, shape (num_sample,)
-             override the sample_method, if not None
-         distance (str): distance metric, 'cosine' (default) or 'euclidean', 'rbf'
-         affinity_focal_gamma (float): affinity matrix parameter, lower t reduce the weak edge weights,
-             resulting in more sharp eigenvectors, default 1.0
-         indirect_connection (bool): include indirect connection in the subgraph, default True
-         indirect_pca_dim (int): default 100, PCA dimension to reduce the node dimension, only applied to
-             the not sampled nodes, not applied to the sampled nodes
-         device (str): device to use for computation, if None, will not change device
-             a good practice is to pass features by CPU since it's usually large,
-             and move subgraph affinity to GPU to speed up eigenvector computation
-         eig_solver (str): eigen decompose solver, 'svd_lowrank' (default), 'lobpcg', 'svd', 'eigh'
-             'svd_lowrank' is recommended for large scale graph, it's the fastest
-             they correspond to torch.svd_lowrank, torch.lobpcg, torch.svd, torch.linalg.eigh
-         normalize_features (bool): normalize input features before computing affinity matrix,
-             default 'None' is True for cosine distance, False for euclidean distance and rbf
-         matmul_chunk_size (int): chunk size for matrix multiplication
-             large matrix multiplication is chunked to reduce memory usage,
-             smaller chunk size will reduce memory usage but slower computation, default 8096
-         make_orthogonal (bool): make eigenvectors orthogonal after propagation, default True
-         verbose (bool): show progress bar when propagating eigenvectors from subgraph to full graph
-         no_propagation (bool): if True, skip the eigenvector propagation step, only return the subgraph eigenvectors
-         move_output_to_cpu (bool): move output to CPU, set to True if you have memory issue
-     Returns:
-         (torch.Tensor): eigenvectors, shape (n_samples, num_eig)
-         (torch.Tensor): eigenvalues, sorted in descending order, shape (num_eig,)
-         (torch.Tensor): sampled_indices used by Nystrom-like approximation subgraph, shape (num_sample,)
-     """
-
-     # check if features dimension greater than num_eig
-     if eig_solver in ["svd_lowrank", "lobpcg"]:
-         assert features.shape[0] > (
-             num_eig * 2
-         ), "number of nodes should be greater than 2*num_eig"
-     if eig_solver in ["svd", "eigh"]:
-         assert (
-             features.shape[0] > num_eig
-         ), "number of nodes should be greater than num_eig"
-
-     assert distance in ["cosine", "euclidean", "rbf"], "distance should be 'cosine', 'euclidean', 'rbf'"
-
-     if normalize_features:
-         # features need to be normalized for affinity matrix computation (cosine distance)
-         features = torch.nn.functional.normalize(features, dim=-1)
-
-     if precomputed_sampled_indices is not None:
-         sampled_indices = precomputed_sampled_indices
-     else:
-         sampled_indices = run_subgraph_sampling(
-             features,
-             num_sample=num_sample,
-             sample_method=sample_method,
-         )
-
-     sampled_features = features[sampled_indices]
-     # move subgraph gpu to speed up
-     original_device = sampled_features.device
-     device = original_device if device is None else device
-     sampled_features = sampled_features.to(device)
-
-     # compute affinity matrix on subgraph
-     A = affinity_from_features(
-         sampled_features,
-         affinity_focal_gamma=affinity_focal_gamma,
-         distance=distance,
-     )
-
-     # check if all nodes are sampled, if so, no need for Nystrom approximation
-     not_sampled = torch.full((features.shape[0],), True)
-     not_sampled[sampled_indices] = False
-     _n_not_sampled = not_sampled.sum()
-
-     if _n_not_sampled == 0:
-         # if sampled all nodes, no need for nyström approximation
-         eigen_vector, eigen_value = ncut(A, num_eig, eig_solver=eig_solver)
-         return eigen_vector, eigen_value, sampled_indices
-
-     # 1) PCA to reduce the node dimension for the not sampled nodes
-     # 2) compute indirect connection on the PC nodes
-     if _n_not_sampled > 0 and indirect_connection:
-         indirect_pca_dim = min(indirect_pca_dim, *features.shape)
-         U, S, V = torch.pca_lowrank(features[not_sampled].T, q=indirect_pca_dim)
-         S = S / math.sqrt(_n_not_sampled)
-         feature_B_T = U @ torch.diag(S)
-         feature_B = feature_B_T.T
-         feature_B = feature_B.to(device)
-
-         B = affinity_from_features(
-             sampled_features,
-             feature_B,
-             affinity_focal_gamma=affinity_focal_gamma,
-             distance=distance,
-             fill_diagonal=False,
-         )
-         # P is 1-hop random walk matrix
-         B_row = B / B.sum(dim=1, keepdim=True)
-         B_col = B / B.sum(dim=0, keepdim=True)
-         P = B_row @ B_col.T
-         P = (P + P.T) / 2
-         # fill diagonal with 0
-         P[torch.arange(P.shape[0]), torch.arange(P.shape[0])] = 0
-         A = A + P
-
-     # compute normalized cut on the subgraph
-     eigen_vector, eigen_value = ncut(A, num_eig, eig_solver=eig_solver)
-     eigen_vector = eigen_vector.to(dtype=features.dtype, device=original_device)
-     eigen_value = eigen_value.to(dtype=features.dtype, device=original_device)
-
-     if no_propagation:
-         return eigen_vector, eigen_value, sampled_indices
-
-     # propagate eigenvectors from subgraph to full graph
-     eigen_vector = propagate_knn(
-         eigen_vector,
-         features,
-         sampled_features,
-         knn,
-         distance=distance,
-         chunk_size=matmul_chunk_size,
-         device=device,
-         use_tqdm=verbose,
-         move_output_to_cpu=move_output_to_cpu,
-     )
-
-     # post-hoc orthogonalization
-     if make_orthogonal:
-         eigen_vector = gram_schmidt(eigen_vector)
-
-     return eigen_vector, eigen_value, sampled_indices
-
-
- def normalized_affinity_transform(D: torch.Tensor, affinity_focal_gamma: float):
-     """Compute Laplacian-normalized affinity matrix from input features.
-
-     Args:
-         features (torch.Tensor): input features, shape (n_samples, n_features)
-         features_B (torch.Tensor, optional): optional, if not None, compute affinity between two features
-         affinity_focal_gamma (float): affinity matrix parameter, lower t reduce the edge weights
-             on weak connections, default 1.0
-         distance (str): distance metric, 'cosine' (default) or 'euclidean', 'rbf'.
-         normalize_features (bool): normalize input features before computing affinity matrix
-
-     Returns:
-         (torch.Tensor): affinity matrix, shape (n_samples, n_samples)
-     """
-     # make sure D is symmetric
-     D = (D + D.T) / 2
-     A = torch.exp(-D / affinity_focal_gamma)
-
-     # symmetrical normalization; A = D^(-1/2) A D^(-1/2)
-     D = A.sum(dim=-1).detach().clone()
-     A /= torch.sqrt(D)[:, None]
-     A /= torch.sqrt(D)[None, :]
-     return A
-
-
- def ncut(
-     A: torch.Tensor,
-     num_eig: int = 100,
-     eig_solver: Literal["svd_lowrank", "lobpcg", "svd", "eigh"] = "svd_lowrank",
- ):
-     """PyTorch implementation of Normalized cut without Nystrom-like approximation.
-
-     Args:
-         A (torch.Tensor): affinity matrix, shape (n_samples, n_samples)
-         num_eig (int): number of eigenvectors to return
-         eig_solver (str): eigen decompose solver, ['svd_lowrank', 'lobpcg', 'svd', 'eigh']
-
-     Returns:
-         (torch.Tensor): eigenvectors corresponding to the eigenvalues, shape (n_samples, num_eig)
-         (torch.Tensor): eigenvalues of the eigenvectors, sorted in descending order
-     """
-     # make sure A is symmetric
-     A = (A + A.T) / 2
-
-     # symmetrical normalization; A = D^(-1/2) A D^(-1/2)
-     D = A.sum(dim=-1).detach().clone()
-     A /= torch.sqrt(D)[:, None]
-     A /= torch.sqrt(D)[None, :]
-
-     # compute eigenvectors
-     eigen_vector, eigen_value = solve_eig(A, num_eig, eig_solver)
+         unsampled_indices, V_unsampled = NCUT._fit_helper(self, features, precomputed_sampled_indices)
+         V_sampled, L = OnlineNystrom.transform(self)

-     if eigen_value.min() < 0:
-         logging.warning(
-             "negative eigenvalues detected, please make sure the affinity matrix is positive definite"
-         )
-
-     return eigen_vector, eigen_value
-
-
- def gram_schmidt(matrix):
-     """Orthogonalize a matrix column-wise using the Gram-Schmidt process.
-
-     Args:
-         matrix (torch.Tensor): A matrix to be orthogonalized (m x n).
-             the second dimension is orthogonalized
-     Returns:
-         torch.Tensor: Orthogonalized matrix (m x n).
-     """
-
-     # Get the number of rows (m) and columns (n) of the input matrix
-     m, n = matrix.shape
-
-     # Create an empty matrix to store the orthogonalized columns
-     orthogonal_matrix = torch.zeros((m, n), dtype=matrix.dtype)
-
-     for i in range(n):
-         # Start with the i-th column of the input matrix
-         vec = matrix[:, i]
-
-         for j in range(i):
-             # Subtract the projection of vec onto the j-th orthogonal column
-             proj = torch.dot(orthogonal_matrix[:, j], matrix[:, i]) / torch.dot(
-                 orthogonal_matrix[:, j], orthogonal_matrix[:, j]
-             )
-             vec = vec - proj * orthogonal_matrix[:, j]
-
-         # Store the orthogonalized vector
-         orthogonal_matrix[:, i] = vec / torch.norm(vec)
-
-     return orthogonal_matrix
-
-
- def correct_rotation(eigen_vector):
-     # correct the random rotation (flipping sign) of eigenvectors
-     rand_w = torch.ones(
-         eigen_vector.shape[0], device=eigen_vector.device, dtype=eigen_vector.dtype
-     )
-     s = rand_w[None, :] @ eigen_vector
-     s = s.sign()
-     return eigen_vector * s
-
-
- # Multiclass Spectral Clustering, SX Yu, J Shi, 2003
- def _discretisation_eigenvector(eigen_vector):
-     # Function that discretizes rotated eigenvectors
-     n, k = eigen_vector.shape
-
-     # Find the maximum index along each row
-     _, J = torch.max(eigen_vector, dim=1)
-     Y = torch.zeros(n, k, device=eigen_vector.device).scatter_(1, J.unsqueeze(1), 1)
-
-     return Y
+         if unsampled_indices is not None:
+             V = torch.zeros((len(unsampled_indices), self.n_components), device=features.device)
+             V[~unsampled_indices] = V_sampled
+             V[unsampled_indices] = V_unsampled
+         else:
+             V = V_sampled
+         return V, L


- def kway_ncut(eigen_vectors: torch.Tensor, max_iter=300, return_rotation=False):
+ def axis_align(eigen_vectors: torch.Tensor, max_iter=300):
      """Multiclass Spectral Clustering, SX Yu, J Shi, 2003

      Args:
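Note how `fit_transform` now reassembles the output in input order: rows at the sampled indices come from `OnlineNystrom.transform(self)`, the rest from the streaming `update` inside `_fit_helper`. Out-of-sample extension after fitting is no longer defined in this file; presumably it is inherited from `OnlineNystrom`, mirroring the kernel's `transform(features)` shown earlier. Continuing the previous sketch under that assumption:

    new_features = torch.rand(500, 100)
    new_eigenvectors, _ = ncut.transform(new_features)   # inherited; signature assumed
    print(new_eigenvectors.shape)                        # (500, 20)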
@@ -482,80 +243,36 @@ def kway_ncut(eigen_vectors: torch.Tensor, max_iter=300, return_rotation=False):
482
243
  """
483
244
  # Normalize eigenvectors
484
245
  n, k = eigen_vectors.shape
485
- vm = torch.sqrt(torch.sum(eigen_vectors ** 2, dim=1))
486
- eigen_vectors = eigen_vectors / vm.unsqueeze(1)
246
+ eigen_vectors = Fn.normalize(eigen_vectors, p=2, dim=-1)
487
247
 
488
248
  # Initialize R matrix with the first column from a random row of EigenVectors
489
- R = torch.zeros(k, k, device=eigen_vectors.device)
490
- R[:, 0] = eigen_vectors[torch.randint(0, n, (1,))].squeeze()
249
+ R = torch.empty((k, k), device=eigen_vectors.device)
250
+ R[0] = eigen_vectors[torch.randint(0, n, (1,))].squeeze()
491
251
 
492
252
  # Loop to populate R with k orthogonal directions
493
253
  c = torch.zeros(n, device=eigen_vectors.device)
494
- for j in range(1, k):
495
- c += torch.abs(eigen_vectors @ R[:, j - 1])
496
- _, i = torch.min(c, dim=0)
497
- R[:, j] = eigen_vectors[i]
254
+ for i in range(1, k):
255
+ c += torch.abs(eigen_vectors @ R[i - 1])
256
+ R[i] = eigen_vectors[torch.argmin(c, dim=0)]
498
257
 
499
258
  # Iterative optimization loop
500
- last_objective_value = 0
501
- exit_loop = False
502
- nb_iterations_discretisation = 0
259
+ eps = torch.finfo(torch.float32).eps
260
+ prev_objective = torch.inf
261
+ for _ in range(max_iter):
262
+ # Discretize the projected eigenvectors
263
+ idx = torch.argmax(eigen_vectors @ R.mT, dim=-1)
264
+ M = torch.zeros((k, k)).index_add_(0, idx, eigen_vectors)
503
265
 
504
- while not exit_loop:
505
- nb_iterations_discretisation += 1
266
+ # Compute the NCut value
267
+ objective = torch.norm(M)
506
268
 
507
- # Discretize the projected eigenvectors
508
- eigenvectors_discrete = _discretisation_eigenvector(eigen_vectors @ R)
269
+ # Check for convergence
270
+ if torch.abs(objective - prev_objective) < eps:
271
+ break
272
+ prev_objective = objective
509
273
 
510
274
  # SVD decomposition
511
- U, S, Vh = torch.linalg.svd(eigenvectors_discrete.T @ eigen_vectors, full_matrices=False)
512
- V = Vh.T
275
+ U, S, Vh = torch.linalg.svd(M, full_matrices=False)
276
+ R = U @ Vh
513
277
 
514
- # Compute the Ncut value
515
- ncut_value = 2 * (n - torch.sum(S))
516
-
517
- # Check for convergence
518
- if torch.abs(ncut_value - last_objective_value) < torch.finfo(
519
- torch.float32).eps or nb_iterations_discretisation > max_iter:
520
- exit_loop = True
521
- else:
522
- last_objective_value = ncut_value
523
- R = V @ U.T
524
-
525
- if return_rotation:
526
- return eigenvectors_discrete, R
527
-
528
- return eigenvectors_discrete
529
-
530
-
531
- def axis_align(eigen_vectors, max_iter=300):
532
- return kway_ncut(eigen_vectors, max_iter=max_iter, return_rotation=True)
533
-
534
-
535
- ## for backward compatibility ##
536
-
537
- try:
538
-
539
- from .propagation_utils import (
540
- propagate_nearest,
541
- propagate_eigenvectors,
542
- quantile_normalize,
543
- quantile_min_max,
544
- farthest_point_sampling,
545
- )
546
- from .visualize_utils import (
547
- eigenvector_to_rgb,
548
- rgb_from_tsne_3d,
549
- rgb_from_umap_sphere,
550
- rgb_from_tsne_2d,
551
- rgb_from_umap_3d,
552
- rgb_from_umap_2d,
553
- rotate_rgb_cube,
554
- convert_to_lab_color,
555
- _transform_heatmap,
556
- _clean_mask,
557
- get_mask,
558
- )
559
-
560
- except ImportError:
561
- print("some of viualization and nystrom_utils are not imported")
278
+ return Fn.one_hot(idx, num_classes=k).to(torch.float), R
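The rewritten `axis_align` folds `kway_ncut`, `_discretisation_eigenvector`, and the old `axis_align` wrapper into one function that always returns the one-hot discretization together with the rotation `R` (note that `M` is allocated on the default device, so this version is CPU-oriented as written). A usage sketch, assuming both names are exported by the package:

    import torch
    from nystrom_ncut import NCUT, axis_align   # exports assumed

    V, L = NCUT(n_components=6).fit_transform(torch.rand(1000, 64))
    onehot, R = axis_align(V)                # one-hot [1000 x 6], rotation [6 x 6]
    labels = onehot.argmax(dim=-1)           # cluster id per node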