nystrom-ncut 0.0.1__py3-none-any.whl

@@ -0,0 +1,241 @@
import logging
from typing import Literal, Tuple

import torch

from .nystrom import (
    EigSolverOptions,
    OnlineKernel,
    OnlineNystrom,
    solve_eig,
)
from .propagation_utils import (
    affinity_from_features,
    run_subgraph_sampling,
)


DistanceOptions = Literal["cosine", "euclidean", "rbf"]

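
# The affinity helpers are imported from .propagation_utils. As a point of
# reference only, a hypothetical stand-in consistent with how
# affinity_from_features is called below (cosine case, no diagonal fill) might
# look like this minimal sketch:
def _toy_affinity_sketch(
    anchors: torch.Tensor,
    features: torch.Tensor = None,
    affinity_focal_gamma: float = 1.0,
) -> torch.Tensor:
    features = anchors if features is None else features
    a = torch.nn.functional.normalize(anchors, dim=-1)
    f = torch.nn.functional.normalize(features, dim=-1)
    distances = 1 - a @ f.mT  # cosine distance, in [0, 2]
    return torch.exp(-distances / affinity_focal_gamma)  # [n x m]
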

class LaplacianKernel(OnlineKernel):
    def __init__(
        self,
        affinity_focal_gamma: float,
        distance: DistanceOptions,
        eig_solver: EigSolverOptions,
    ):
        self.affinity_focal_gamma = affinity_focal_gamma
        self.distance: DistanceOptions = distance
        self.eig_solver: EigSolverOptions = eig_solver

        # Anchor matrices
        self.anchor_features: torch.Tensor = None  # [n x d]
        self.A: torch.Tensor = None  # [n x n]
        self.Ainv: torch.Tensor = None  # [n x n]

        # Updated matrices
        self.a_r: torch.Tensor = None  # [n]
        self.b_r: torch.Tensor = None  # [n]

    def fit(self, features: torch.Tensor) -> None:
        self.anchor_features = features  # [n x d]
        self.A = affinity_from_features(
            self.anchor_features,  # [n x d]
            affinity_focal_gamma=self.affinity_focal_gamma,
            distance=self.distance,
            fill_diagonal=False,
        )  # [n x n]
        U, L = solve_eig(
            self.A,
            num_eig=features.shape[-1] + 1,
            eig_solver=self.eig_solver,
        )  # [n x (d + 1)], [d + 1]
        self.Ainv = U @ torch.diag(1 / L) @ U.mT  # [n x n]
        self.a_r = torch.sum(self.A, dim=-1)  # [n]
        self.b_r = torch.zeros_like(self.a_r)  # [n]

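    # Note: Ainv above is a truncated pseudo-inverse of the anchor affinity A,
    # built from its top (d + 1) eigenpairs; update() uses it to extend degree
    # estimates to points outside the anchor set.
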
    def update(self, features: torch.Tensor) -> torch.Tensor:
        B = affinity_from_features(
            self.anchor_features,  # [n x d]
            features,  # [m x d]
            affinity_focal_gamma=self.affinity_focal_gamma,
            distance=self.distance,
            fill_diagonal=False,
        )  # [n x m]
        b_r = torch.sum(B, dim=-1)  # [n]
        b_c = torch.sum(B, dim=-2)  # [m]
        self.b_r = self.b_r + b_r  # [n]

        rowscale = self.a_r + self.b_r  # [n]
        colscale = b_c + B.mT @ self.Ainv @ self.b_r  # [m]
        scale = (rowscale[:, None] * colscale) ** -0.5  # [n x m]
        return (B * scale).mT  # [m x n]

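    # Note: a_r holds each anchor's affinity mass to the other anchors, b_r
    # accumulates its affinity mass to all streamed points, and
    # colscale = b_c + B.mT @ Ainv @ b_r estimates each new point's degree,
    # where B.mT @ Ainv @ b_r is the Nystrom completion of its (unobserved)
    # affinities to previously streamed points.
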
    def transform(self, features: torch.Tensor = None) -> torch.Tensor:
        rowscale = self.a_r + self.b_r  # [n]
        if features is None:
            B = self.A  # [n x n]
            colscale = rowscale  # [n]
        else:
            B = affinity_from_features(
                self.anchor_features,  # [n x d]
                features,  # [m x d]
                affinity_focal_gamma=self.affinity_focal_gamma,
                distance=self.distance,
                fill_diagonal=False,
            )  # [n x m]
            b_c = torch.sum(B, dim=-2)  # [m]
            colscale = b_c + B.mT @ self.Ainv @ self.b_r  # [m]
        scale = (rowscale[:, None] * colscale) ** -0.5  # [n x m]
        return (B * scale).mT  # [m x n]

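
# A minimal sketch (hypothetical helper, arbitrary sizes) of the normalization
# LaplacianKernel performs: right after fit(), while b_r is still zero,
# transform() reduces to the symmetric degree scaling D^-1/2 A D^-1/2 used in
# normalized cuts.
def _demo_laplacian_normalization() -> None:
    torch.manual_seed(0)
    A = torch.rand(5, 5)
    A = (A + A.mT) / 2  # symmetric affinity stand-in
    d = torch.sum(A, dim=-1)  # degrees, i.e. a_r after fit()
    scale = (d[:, None] * d) ** -0.5  # rowscale == colscale == d
    normalized = (A * scale).mT  # matches transform() with B = A
    print(normalized)  # [5 x 5], D^-1/2 A D^-1/2

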
class NewNCUT(OnlineNystrom):
    """Nystrom Normalized Cut for large-scale graphs."""

    def __init__(
        self,
        num_eig: int = 100,
        affinity_focal_gamma: float = 1.0,
        num_sample: int = 10000,
        sample_method: Literal["farthest", "random"] = "farthest",
        distance: DistanceOptions = "cosine",
        eig_solver: EigSolverOptions = "svd_lowrank",
        normalize_features: bool = None,
        device: str = None,
        move_output_to_cpu: bool = False,
        matmul_chunk_size: int = 8096,
    ):
        """
        Args:
            num_eig (int): number of top eigenvectors to return
            affinity_focal_gamma (float): affinity matrix temperature; a lower value
                weakens the not-so-well-connected edges and yields sharper eigenvectors
            num_sample (int): number of samples for the Nystrom-like approximation;
                reduce only if memory is insufficient, increase for a better approximation
            sample_method (str): subgraph sampling method, ['farthest', 'random'];
                farthest point sampling is recommended for better Nystrom approximation accuracy
            distance (str): distance metric for the affinity matrix, ['cosine', 'euclidean', 'rbf']
            eig_solver (str): eigendecomposition solver, ['svd_lowrank', 'lobpcg', 'svd', 'eigh']
            normalize_features (bool): normalize input features before computing the affinity
                matrix; the default None resolves to True for cosine distance and False for
                euclidean and rbf
            device (str): device to use for the eigen computation;
                moving it to a GPU speeds things up (~5x faster)
            move_output_to_cpu (bool): move the output to CPU; set to True if you run into memory issues
            matmul_chunk_size (int): chunk size for large-scale matrix multiplication
        """
        OnlineNystrom.__init__(
            self,
            n_components=num_eig,
            kernel=LaplacianKernel(affinity_focal_gamma, distance, eig_solver),
            eig_solver=eig_solver,
            chunk_size=matmul_chunk_size,
        )
        self.num_sample = num_sample
        self.sample_method = sample_method
        self.distance = distance
        self.normalize_features = normalize_features
        if self.normalize_features is None:
            # resolve the default: normalize for cosine, leave as-is otherwise
            if distance in ["cosine"]:
                self.normalize_features = True
            if distance in ["euclidean", "rbf"]:
                self.normalize_features = False

        self.device = device
        self.move_output_to_cpu = move_output_to_cpu
        self.matmul_chunk_size = matmul_chunk_size

    def _fit_helper(
        self,
        features: torch.Tensor,
        precomputed_sampled_indices: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # move the subgraph to the GPU to speed up the eigen computation
        original_device = features.device
        device = original_device if self.device is None else self.device

        _n = features.shape[0]
        if self.num_sample >= _n:
            logging.info(
                f"NCUT num_sample is at least the number of input samples; the Nystrom approximation is not needed, setting num_sample={_n}"
            )
            self.num_sample = _n

        # check that the number of nodes is large enough for num_eig
        if self.eig_solver in ["svd_lowrank", "lobpcg"]:
            assert (
                _n >= self.n_components * 2
            ), "number of nodes should be at least 2 * num_eig"
        elif self.eig_solver in ["svd", "eigh"]:
            assert (
                _n >= self.n_components
            ), "number of nodes should be at least num_eig"

        assert self.distance in ["cosine", "euclidean", "rbf"], "distance should be 'cosine', 'euclidean', or 'rbf'"

        if self.normalize_features:
            # features need to be normalized for affinity matrix computation (cosine distance)
            features = torch.nn.functional.normalize(features, dim=-1)

        if precomputed_sampled_indices is not None:
            sampled_indices = precomputed_sampled_indices
        else:
            sampled_indices = run_subgraph_sampling(
                features,
                num_sample=self.num_sample,
                sample_method=self.sample_method,
            )
        sampled_features = features[sampled_indices].to(device)
        OnlineNystrom.fit(self, sampled_features)

        _n_not_sampled = _n - len(sampled_features)
        if _n_not_sampled > 0:
            # boolean mask that is True exactly at the unsampled positions
            unsampled_indices = torch.full((_n,), True).scatter(0, sampled_indices, False)
            unsampled_features = features[unsampled_indices].to(device)
            V_unsampled, _ = OnlineNystrom.update(self, unsampled_features)
        else:
            unsampled_indices = V_unsampled = None
        return unsampled_indices, V_unsampled

    def fit(
        self,
        features: torch.Tensor,
        precomputed_sampled_indices: torch.Tensor = None,
    ):
        """Fit Nystrom Normalized Cut on the input features.
        Args:
            features (torch.Tensor): input features, shape (n_samples, n_features)
            precomputed_sampled_indices (torch.Tensor): precomputed sampled indices, shape (num_sample,);
                overrides sample_method if not None
        Returns:
            (NewNCUT): self
        """
        NewNCUT._fit_helper(self, features, precomputed_sampled_indices)
        return self

    def fit_transform(
        self,
        features: torch.Tensor,
        precomputed_sampled_indices: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            features (torch.Tensor): input features, shape (n_samples, n_features)
            precomputed_sampled_indices (torch.Tensor): precomputed sampled indices, shape (num_sample,);
                overrides sample_method if not None

        Returns:
            (torch.Tensor): eigenvectors, shape (n_samples, num_eig)
            (torch.Tensor): eigenvalues, sorted in descending order, shape (num_eig,)
        """
        unsampled_indices, V_unsampled = NewNCUT._fit_helper(self, features, precomputed_sampled_indices)
        V_sampled, L = OnlineNystrom.transform(self)

        if unsampled_indices is not None:
            # allocate on the same device as the computed eigenvectors
            V = torch.zeros((len(unsampled_indices), self.n_components), device=V_sampled.device)
            V[~unsampled_indices] = V_sampled
            V[unsampled_indices] = V_unsampled
        else:
            V = V_sampled
        return V, L
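

# A minimal usage sketch (hypothetical sizes; assumes the package imports
# resolve as above): fit on all features, with the Nystrom approximation
# anchored on a 500-point subsample.
def _demo_newncut() -> None:
    torch.manual_seed(0)
    features = torch.rand(2000, 64)
    ncut = NewNCUT(num_eig=20, num_sample=500, distance="cosine")
    eigenvectors, eigenvalues = ncut.fit_transform(features)
    print(eigenvectors.shape)  # torch.Size([2000, 20])
    print(eigenvalues.shape)  # torch.Size([20])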
@@ -0,0 +1,170 @@
from typing import Literal, Tuple

import torch


EigSolverOptions = Literal["svd_lowrank", "lobpcg", "svd", "eigh"]


class OnlineKernel:
    def fit(self, features: torch.Tensor) -> None:  # [n x d]
        raise NotImplementedError()

    def update(self, features: torch.Tensor) -> torch.Tensor:  # [m x d] -> [m x n]
        raise NotImplementedError()

    def transform(self, features: torch.Tensor = None) -> torch.Tensor:  # [m x d] -> [m x n]
        raise NotImplementedError()

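
# A hedged sketch of the OnlineKernel contract with a toy dot-product kernel;
# _LinearDemoKernel is hypothetical and only illustrates the expected
# fit / update / transform shapes, not a kernel the package ships.
class _LinearDemoKernel(OnlineKernel):
    def fit(self, features: torch.Tensor) -> None:  # [n x d]
        self.anchors = features

    def update(self, features: torch.Tensor) -> torch.Tensor:  # [m x d] -> [m x n]
        return features @ self.anchors.mT

    def transform(self, features: torch.Tensor = None) -> torch.Tensor:  # [m x d] -> [m x n]
        features = self.anchors if features is None else features
        return features @ self.anchors.mT

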
class OnlineNystrom:
    def __init__(
        self,
        n_components: int,
        kernel: OnlineKernel,
        eig_solver: EigSolverOptions,
        chunk_size: int = 8192,
    ):
        """
        Args:
            n_components (int): number of top eigenvectors to return
            kernel (OnlineKernel): online kernel that computes pairwise matrix entries from input features and supports updates
            eig_solver (str): eigendecomposition solver, ['svd_lowrank', 'lobpcg', 'svd', 'eigh']
        """
        self.n_components: int = n_components
        self.kernel: OnlineKernel = kernel
        self.eig_solver: EigSolverOptions = eig_solver
        self.inverse_approximation_dim: int = None

        self.chunk_size = chunk_size

        # Anchor matrices
        self.anchor_features: torch.Tensor = None  # [n x d]
        self.A: torch.Tensor = None  # [n x n]
        self.Ahinv: torch.Tensor = None  # [n x n]
        self.Ahinv_UL: torch.Tensor = None  # [n x indirect_pca_dim]
        self.Ahinv_VT: torch.Tensor = None  # [indirect_pca_dim x n]

        # Updated matrices
        self.S: torch.Tensor = None  # [n x n]
        self.transform_matrix: torch.Tensor = None  # [n x n_components]
        self.LS: torch.Tensor = None  # [n_components]

    def fit(self, features: torch.Tensor):
        OnlineNystrom.fit_transform(self, features)
        return self

    def fit_transform(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        self.anchor_features = features

        self.kernel.fit(self.anchor_features)
        self.A = self.S = self.kernel.transform()  # [n x n]

        self.inverse_approximation_dim = max(self.n_components, features.shape[-1]) + 1
        U, L = solve_eig(
            self.A,
            num_eig=self.inverse_approximation_dim,
            eig_solver=self.eig_solver,
        )  # [n x (? + 1)], [? + 1]
        self.Ahinv_UL = U * (L ** -0.5)  # [n x (? + 1)]
        self.Ahinv_VT = U.mT  # [(? + 1) x n]
        self.Ahinv = self.Ahinv_UL @ self.Ahinv_VT  # [n x n]

        self.transform_matrix = (U / L)[:, :self.n_components]  # [n x n_components]
        self.LS = L[:self.n_components]  # [n_components]
        return U[:, :self.n_components], L[:self.n_components]  # [n x n_components], [n_components]

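    # Note: Ahinv is an approximate inverse square root of A ("h" for half):
    # with A ~= U diag(L) U^T, Ahinv = U diag(L^-0.5) U^T, kept in factored
    # form (Ahinv_UL, Ahinv_VT) so products with tall matrices stay cheap.
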
    def update(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        n_chunks = -(-len(features) // self.chunk_size)  # ceiling division
        if n_chunks > 1:
            # Chunked version
            chunks = torch.chunk(features, n_chunks, dim=0)
            for chunk in chunks:
                self.kernel.update(chunk)

            compressed_BBT = torch.zeros(
                (self.inverse_approximation_dim, self.inverse_approximation_dim),
                device=features.device, dtype=features.dtype,
            )  # [(? + 1) x (? + 1)]
            for chunk in chunks:
                _B = self.kernel.transform(chunk).mT  # [n x _m]
                _compressed_B = self.Ahinv_VT @ _B  # [(? + 1) x _m]
                compressed_BBT = compressed_BBT + _compressed_B @ _compressed_B.mT  # [(? + 1) x (? + 1)]
            self.S = self.S + self.Ahinv_UL @ compressed_BBT @ self.Ahinv_UL.mT  # [n x n]
            US, self.LS = solve_eig(self.S, self.n_components, self.eig_solver)  # [n x n_components], [n_components]
            self.transform_matrix = self.Ahinv @ US * (self.LS ** -0.5)  # [n x n_components]

            VS = []
            for chunk in chunks:
                VS.append(self.kernel.transform(chunk) @ self.transform_matrix)  # [_m x n_components]
            VS = torch.cat(VS, dim=0)
            return VS, self.LS  # [m x n_components], [n_components]
        else:
            # Unchunked version
            B = self.kernel.update(features).mT  # [n x m]
            compressed_B = self.Ahinv_VT @ B  # [indirect_pca_dim x m]

            self.S = self.S + self.Ahinv_UL @ (compressed_B @ compressed_B.mT) @ self.Ahinv_UL.mT  # [n x n]
            US, self.LS = solve_eig(self.S, self.n_components, self.eig_solver)  # [n x n_components], [n_components]
            self.transform_matrix = self.Ahinv @ US * (self.LS ** -0.5)  # [n x n_components]

            return B.mT @ self.transform_matrix, self.LS  # [m x n_components], [n_components]

    def transform(self, features: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
        if features is None:
            VS = self.A @ self.transform_matrix  # [n x n_components]
        else:
            n_chunks = -(-len(features) // self.chunk_size)  # ceiling division
            if n_chunks > 1:
                # Chunked version
                chunks = torch.chunk(features, n_chunks, dim=0)
                VS = []
                for chunk in chunks:
                    VS.append(self.kernel.transform(chunk) @ self.transform_matrix)  # [_m x n_components]
                VS = torch.cat(VS, dim=0)
            else:
                # Unchunked version
                VS = self.kernel.transform(features) @ self.transform_matrix  # [m x n_components]
        return VS, self.LS  # [m x n_components], [n_components]

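
# A small sketch of the chunking arithmetic used above: -(-m // c) is ceiling
# division, so every row lands in exactly one chunk (sizes here are arbitrary).
def _demo_chunking() -> None:
    features = torch.rand(10, 4)
    chunk_size = 3
    n_chunks = -(-len(features) // chunk_size)  # ceil(10 / 3) == 4
    chunks = torch.chunk(features, n_chunks, dim=0)
    assert sum(len(chunk) for chunk in chunks) == len(features)

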
def solve_eig(
    A: torch.Tensor,
    num_eig: int,
    eig_solver: Literal["svd_lowrank", "lobpcg", "svd", "eigh"],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """PyTorch eigendecomposition of a full symmetric matrix, without Nystrom-like approximation.

    Args:
        A (torch.Tensor): input matrix, shape (n_samples, n_samples)
        num_eig (int): number of eigenvectors to return
        eig_solver (str): eigendecomposition solver, ['svd_lowrank', 'lobpcg', 'svd', 'eigh']

    Returns:
        (torch.Tensor): eigenvectors corresponding to the eigenvalues, shape (n_samples, num_eig)
        (torch.Tensor): eigenvalues of the eigenvectors, sorted in descending order
    """
    # compute eigenvectors
    if eig_solver == "svd_lowrank":  # default
        # only top q eigenvectors, fastest
        eigen_vector, eigen_value, _ = torch.svd_lowrank(A, q=num_eig)
    elif eig_solver == "lobpcg":
        # only top k eigenvectors, fast
        eigen_value, eigen_vector = torch.lobpcg(A, k=num_eig)
    elif eig_solver == "svd":
        # all eigenvectors, slow
        eigen_vector, eigen_value, _ = torch.svd(A)
    elif eig_solver == "eigh":
        # all eigenvectors, slow
        eigen_value, eigen_vector = torch.linalg.eigh(A)
    else:
        raise ValueError(
            "eig_solver should be 'lobpcg', 'svd_lowrank', 'svd' or 'eigh'"
        )

    # sort eigenvectors by eigenvalue in descending order and take the top num_eig
    eigen_value = eigen_value.real
    eigen_vector = eigen_vector.real
    eigen_value, indices = torch.topk(eigen_value, k=num_eig, dim=0)
    eigen_vector = eigen_vector[:, indices]

    # correct the random rotation (sign flip) of the eigenvectors
    eigen_vector = eigen_vector * torch.sum(eigen_vector, dim=0).sign()
    return eigen_vector, eigen_value
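

# A minimal sketch on a small symmetric PSD matrix: solve_eig returns the top
# eigenpairs in descending order, and (U * L**-0.5) @ U.mT reproduces the
# approximate inverse square root ("Ahinv") built in OnlineNystrom; the
# product Ahinv @ A @ Ahinv is then the projector onto the top eigenspace.
def _demo_solve_eig() -> None:
    torch.manual_seed(0)
    X = torch.rand(6, 3)
    A = X @ X.mT + 1e-3 * torch.eye(6)  # symmetric positive definite
    U, L = solve_eig(A, num_eig=3, eig_solver="eigh")
    Ahinv = (U * L ** -0.5) @ U.mT  # approximate A^-1/2
    print(Ahinv @ A @ Ahinv)  # close to U @ U.mT; identity when num_eig = n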