nystrom_ncut-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,241 @@
+ import logging
+ from typing import Literal, Tuple
+
+ import torch
+
+ from .nystrom import (
+     EigSolverOptions,
+     OnlineKernel,
+     OnlineNystrom,
+     solve_eig,
+ )
+ from .propagation_utils import (
+     affinity_from_features,
+     run_subgraph_sampling,
+ )
+
+
+ DistanceOptions = Literal["cosine", "euclidean", "rbf"]
+
+
+ class LaplacianKernel(OnlineKernel):
+     def __init__(
+         self,
+         affinity_focal_gamma: float,
+         distance: DistanceOptions,
+         eig_solver: EigSolverOptions,
+     ):
+         self.affinity_focal_gamma = affinity_focal_gamma
+         self.distance: DistanceOptions = distance
+         self.eig_solver: EigSolverOptions = eig_solver
+
+         # Anchor matrices
+         self.anchor_features: torch.Tensor = None  # [n x d]
+         self.A: torch.Tensor = None  # [n x n]
+         self.Ainv: torch.Tensor = None  # [n x n]
+
+         # Updated matrices
+         self.a_r: torch.Tensor = None  # [n]
+         self.b_r: torch.Tensor = None  # [n]
+
+     def fit(self, features: torch.Tensor) -> None:
+         self.anchor_features = features  # [n x d]
+         self.A = affinity_from_features(
+             self.anchor_features,  # [n x d]
+             affinity_focal_gamma=self.affinity_focal_gamma,
+             distance=self.distance,
+             fill_diagonal=False,
+         )  # [n x n]
+         U, L = solve_eig(
+             self.A,
+             num_eig=features.shape[-1] + 1,
+             eig_solver=self.eig_solver,
+         )  # [n x (d + 1)], [d + 1]
+         self.Ainv = U @ torch.diag(1 / L) @ U.mT  # [n x n]
+         self.a_r = torch.sum(self.A, dim=-1)  # [n]
+         self.b_r = torch.zeros_like(self.a_r)  # [n]
+
+     def update(self, features: torch.Tensor) -> torch.Tensor:
+         B = affinity_from_features(
+             self.anchor_features,  # [n x d]
+             features,  # [m x d]
+             affinity_focal_gamma=self.affinity_focal_gamma,
+             distance=self.distance,
+             fill_diagonal=False,
+         )  # [n x m]
+         b_r = torch.sum(B, dim=-1)  # [n]
+         b_c = torch.sum(B, dim=-2)  # [m]
+         self.b_r = self.b_r + b_r  # [n]
+
+         rowscale = self.a_r + self.b_r  # [n]
+         colscale = b_c + B.mT @ self.Ainv @ self.b_r  # [m]
+         scale = (rowscale[:, None] * colscale) ** -0.5  # [n x m]
+         return (B * scale).mT  # [m x n]
+
+     def transform(self, features: torch.Tensor = None) -> torch.Tensor:
+         rowscale = self.a_r + self.b_r  # [n]
+         if features is None:
+             B = self.A  # [n x n]
+             colscale = rowscale  # [n]
+         else:
+             B = affinity_from_features(
+                 self.anchor_features,  # [n x d]
+                 features,  # [m x d]
+                 affinity_focal_gamma=self.affinity_focal_gamma,
+                 distance=self.distance,
+                 fill_diagonal=False,
+             )  # [n x m]
+             b_c = torch.sum(B, dim=-2)  # [m]
+             colscale = b_c + B.mT @ self.Ainv @ self.b_r  # [m]
+         scale = (rowscale[:, None] * colscale) ** -0.5  # [n x m]
+         return (B * scale).mT  # [m x n]
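The row/column scaling in `update` and `transform` reproduces, row by row, the symmetrically normalized affinity D^{-1/2} W D^{-1/2} used by normalized cut, with the degrees of out-of-sample points estimated through the anchor block (`b_c + B.mT @ self.Ainv @ self.b_r`). For intuition, a dense reference of the quantity being approximated — a sketch for illustration only, not part of the package:

```python
import torch

def normalized_affinity(W: torch.Tensor) -> torch.Tensor:
    # Dense D^{-1/2} W D^{-1/2} over the full affinity matrix W; the
    # LaplacianKernel approximates rows of this matrix using only the
    # anchor block A and the cross block B.
    d = W.sum(dim=-1)                                 # node degrees [n]
    return W * (d[:, None] * d[None, :]) ** -0.5      # [n x n]
```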
+
+
+ class NewNCUT(OnlineNystrom):
+     """Nystrom Normalized Cut for large-scale graphs."""
+
+     def __init__(
+         self,
+         num_eig: int = 100,
+         affinity_focal_gamma: float = 1.0,
+         num_sample: int = 10000,
+         sample_method: Literal["farthest", "random"] = "farthest",
+         distance: DistanceOptions = "cosine",
+         eig_solver: EigSolverOptions = "svd_lowrank",
+         normalize_features: bool = None,
+         device: str = None,
+         move_output_to_cpu: bool = False,
+         matmul_chunk_size: int = 8096,
+     ):
+         """
+         Args:
+             num_eig (int): number of top eigenvectors to return
+             affinity_focal_gamma (float): affinity matrix temperature; a lower value
+                 reduces the weight of weakly connected edges and yields sharper eigenvectors.
+             num_sample (int): number of samples for the Nystrom-like approximation;
+                 reduce it only if memory is insufficient, increase it for a better approximation
+             sample_method (str): subgraph sampling method, ['farthest', 'random'].
+                 farthest point sampling is recommended for better Nystrom approximation accuracy
+             distance (str): distance metric for the affinity matrix, ['cosine', 'euclidean', 'rbf'].
+             eig_solver (str): eigendecomposition solver, ['svd_lowrank', 'lobpcg', 'svd', 'eigh'].
+             normalize_features (bool): normalize input features before computing the affinity
+                 matrix; the default None resolves to True for cosine distance and False for euclidean and rbf
+             device (str): device to use for the eigendecomposition;
+                 moving it to GPU gives a noticeable speedup (~5x faster)
+             move_output_to_cpu (bool): move the output to CPU; set to True if you run into memory issues
+             matmul_chunk_size (int): chunk size for large-scale matrix multiplication
+         """
+         OnlineNystrom.__init__(
+             self,
+             n_components=num_eig,
+             kernel=LaplacianKernel(affinity_focal_gamma, distance, eig_solver),
+             eig_solver=eig_solver,
+             chunk_size=matmul_chunk_size,
+         )
+         self.num_sample = num_sample
+         self.sample_method = sample_method
+         self.distance = distance
+         self.normalize_features = normalize_features
+         if self.normalize_features is None:
+             if distance in ["cosine"]:
+                 self.normalize_features = True
+             if distance in ["euclidean", "rbf"]:
+                 self.normalize_features = False
+
+         self.device = device
+         self.move_output_to_cpu = move_output_to_cpu
+         self.matmul_chunk_size = matmul_chunk_size
+
+     def _fit_helper(
+         self,
+         features: torch.Tensor,
+         precomputed_sampled_indices: torch.Tensor,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # move the sampled subgraph to GPU to speed things up
+         original_device = features.device
+         device = original_device if self.device is None else self.device
+
+         _n = features.shape[0]
+         if self.num_sample >= _n:
+             logging.info(
+                 f"NCUT num_sample is at least the number of input samples; Nystrom approximation is not needed, setting num_sample={_n}"
+             )
+             self.num_sample = _n
+
+         # check that the number of nodes is large enough for num_eig
+         if self.eig_solver in ["svd_lowrank", "lobpcg"]:
+             assert (
+                 _n >= self.n_components * 2
+             ), "number of nodes should be at least 2 * num_eig"
+         elif self.eig_solver in ["svd", "eigh"]:
+             assert (
+                 _n >= self.n_components
+             ), "number of nodes should be at least num_eig"
+
+         assert self.distance in ["cosine", "euclidean", "rbf"], "distance should be 'cosine', 'euclidean', or 'rbf'"
+
+         if self.normalize_features:
+             # features need to be normalized for affinity computation with cosine distance
+             features = torch.nn.functional.normalize(features, dim=-1)
+
+         if precomputed_sampled_indices is not None:
+             sampled_indices = precomputed_sampled_indices
+         else:
+             sampled_indices = run_subgraph_sampling(
+                 features,
+                 num_sample=self.num_sample,
+                 sample_method=self.sample_method,
+             )
+         sampled_features = features[sampled_indices].to(device)
+         OnlineNystrom.fit(self, sampled_features)
+
+         _n_not_sampled = _n - len(sampled_features)
+         if _n_not_sampled > 0:
+             unsampled_indices = torch.full((_n,), True).scatter(0, sampled_indices, False)
+             unsampled_features = features[unsampled_indices].to(device)
+             V_unsampled, _ = OnlineNystrom.update(self, unsampled_features)
+         else:
+             unsampled_indices = V_unsampled = None
+         return unsampled_indices, V_unsampled
+
+     def fit(
+         self,
+         features: torch.Tensor,
+         precomputed_sampled_indices: torch.Tensor = None,
+     ):
+         """Fit Nystrom Normalized Cut on the input features.
+         Args:
+             features (torch.Tensor): input features, shape (n_samples, n_features)
+             precomputed_sampled_indices (torch.Tensor): precomputed sampled indices, shape (num_sample,);
+                 overrides sample_method if not None
+         Returns:
+             (NewNCUT): self
+         """
+         NewNCUT._fit_helper(self, features, precomputed_sampled_indices)
+         return self
+
+     def fit_transform(
+         self,
+         features: torch.Tensor,
+         precomputed_sampled_indices: torch.Tensor = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Args:
+             features (torch.Tensor): input features, shape (n_samples, n_features)
+             precomputed_sampled_indices (torch.Tensor): precomputed sampled indices, shape (num_sample,);
+                 overrides sample_method if not None
+
+         Returns:
+             (torch.Tensor): eigenvectors, shape (n_samples, num_eig)
+             (torch.Tensor): eigenvalues, sorted in descending order, shape (num_eig,)
+         """
+         unsampled_indices, V_unsampled = NewNCUT._fit_helper(self, features, precomputed_sampled_indices)
+         V_sampled, L = OnlineNystrom.transform(self)
+
+         if unsampled_indices is not None:
+             V = torch.zeros((len(unsampled_indices), self.n_components), device=V_sampled.device, dtype=V_sampled.dtype)
+             V[~unsampled_indices] = V_sampled
+             V[unsampled_indices] = V_unsampled
+         else:
+             V = V_sampled
+         return V, L
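End-to-end usage of `NewNCUT`, as a minimal sketch; the import path `nystrom_ncut` and the public export of `NewNCUT` are assumptions based on the package name, not confirmed by this diff:

```python
import torch
from nystrom_ncut import NewNCUT  # assumed import path

features = torch.randn(5000, 64)                     # (n_samples, n_features)
ncut = NewNCUT(num_eig=20, num_sample=1000, distance="cosine")
eigenvectors, eigenvalues = ncut.fit_transform(features)
print(eigenvectors.shape)                            # torch.Size([5000, 20])
print(eigenvalues.shape)                             # torch.Size([20])
```

With `num_sample=1000`, only the sampled subgraph is eigendecomposed; the remaining 4000 rows are filled in through `OnlineNystrom.update`.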
@@ -0,0 +1,170 @@
+ from typing import Literal, Tuple
+
+ import torch
+
+
+ EigSolverOptions = Literal["svd_lowrank", "lobpcg", "svd", "eigh"]
+
+
+ class OnlineKernel:
+     def fit(self, features: torch.Tensor) -> None:  # [n x d]
+         raise NotImplementedError()
+
+     def update(self, features: torch.Tensor) -> torch.Tensor:  # [m x d] -> [m x n]
+         raise NotImplementedError()
+
+     def transform(self, features: torch.Tensor = None) -> torch.Tensor:  # [m x d] -> [m x n]
+         raise NotImplementedError()
+
+
+ class OnlineNystrom:
+     def __init__(
+         self,
+         n_components: int,
+         kernel: OnlineKernel,
+         eig_solver: EigSolverOptions,
+         chunk_size: int = 8192,
+     ):
+         """
+         Args:
+             n_components (int): number of top eigenvectors to return
+             kernel (OnlineKernel): online kernel that computes pairwise matrix entries from input features and supports updates
+             eig_solver (str): eigendecomposition solver, ['svd_lowrank', 'lobpcg', 'svd', 'eigh'].
+         """
+         self.n_components: int = n_components
+         self.kernel: OnlineKernel = kernel
+         self.eig_solver: EigSolverOptions = eig_solver
+         self.inverse_approximation_dim: int = None
+
+         self.chunk_size = chunk_size
+
+         # Anchor matrices
+         self.anchor_features: torch.Tensor = None  # [n x d]
+         self.A: torch.Tensor = None  # [n x n]
+         self.Ahinv: torch.Tensor = None  # [n x n]
+         self.Ahinv_UL: torch.Tensor = None  # [n x inverse_approximation_dim]
+         self.Ahinv_VT: torch.Tensor = None  # [inverse_approximation_dim x n]
+
+         # Updated matrices
+         self.S: torch.Tensor = None  # [n x n]
+         self.transform_matrix: torch.Tensor = None  # [n x n_components]
+         self.LS: torch.Tensor = None  # [n_components]
+
+     def fit(self, features: torch.Tensor):
+         OnlineNystrom.fit_transform(self, features)
+         return self
+
+     def fit_transform(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         self.anchor_features = features
+
+         self.kernel.fit(self.anchor_features)
+         self.A = self.S = self.kernel.transform()  # [n x n]
+
+         self.inverse_approximation_dim = max(self.n_components, features.shape[-1]) + 1
+         U, L = solve_eig(
+             self.A,
+             num_eig=self.inverse_approximation_dim,
+             eig_solver=self.eig_solver,
+         )  # [n x (? + 1)], [? + 1]
+         self.Ahinv_UL = U * (L ** -0.5)  # [n x (? + 1)]
+         self.Ahinv_VT = U.mT  # [(? + 1) x n]
+         self.Ahinv = self.Ahinv_UL @ self.Ahinv_VT  # [n x n]
+
+         self.transform_matrix = (U / L)[:, :self.n_components]  # [n x n_components]
+         self.LS = L[:self.n_components]  # [n_components]
+         return U[:, :self.n_components], L[:self.n_components]  # [n x n_components], [n_components]
+
+     def update(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         n_chunks = -(-len(features) // self.chunk_size)  # ceil(len(features) / chunk_size)
+         if n_chunks > 1:
+             # Chunked version
+             chunks = torch.chunk(features, n_chunks, dim=0)
+             for chunk in chunks:
+                 self.kernel.update(chunk)
+
+             compressed_BBT = torch.zeros((self.inverse_approximation_dim, self.inverse_approximation_dim), device=features.device, dtype=features.dtype)  # [(? + 1) x (? + 1)]
+             for chunk in chunks:
+                 _B = self.kernel.transform(chunk).mT  # [n x _m]
+                 _compressed_B = self.Ahinv_VT @ _B  # [(? + 1) x _m]
+                 compressed_BBT = compressed_BBT + _compressed_B @ _compressed_B.mT  # [(? + 1) x (? + 1)]
+             self.S = self.S + self.Ahinv_UL @ compressed_BBT @ self.Ahinv_UL.mT  # [n x n]
+             US, self.LS = solve_eig(self.S, self.n_components, self.eig_solver)  # [n x n_components], [n_components]
+             self.transform_matrix = self.Ahinv @ US * (self.LS ** -0.5)  # [n x n_components]
+
+             VS = []
+             for chunk in chunks:
+                 VS.append(self.kernel.transform(chunk) @ self.transform_matrix)  # [_m x n_components]
+             VS = torch.cat(VS, dim=0)
+             return VS, self.LS  # [m x n_components], [n_components]
+         else:
+             # Unchunked version
+             B = self.kernel.update(features).mT  # [n x m]
+             compressed_B = self.Ahinv_VT @ B  # [(? + 1) x m]
+
+             self.S = self.S + self.Ahinv_UL @ (compressed_B @ compressed_B.mT) @ self.Ahinv_UL.mT  # [n x n]
+             US, self.LS = solve_eig(self.S, self.n_components, self.eig_solver)  # [n x n_components], [n_components]
+             self.transform_matrix = self.Ahinv @ US * (self.LS ** -0.5)  # [n x n_components]
+
+             return B.mT @ self.transform_matrix, self.LS  # [m x n_components], [n_components]
+
+     def transform(self, features: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
+         if features is None:
+             VS = self.A @ self.transform_matrix  # [n x n_components]
+         else:
+             n_chunks = -(-len(features) // self.chunk_size)  # ceil(len(features) / chunk_size)
+             if n_chunks > 1:
+                 # Chunked version
+                 chunks = torch.chunk(features, n_chunks, dim=0)
+                 VS = []
+                 for chunk in chunks:
+                     VS.append(self.kernel.transform(chunk) @ self.transform_matrix)  # [_m x n_components]
+                 VS = torch.cat(VS, dim=0)
+             else:
+                 # Unchunked version
+                 VS = self.kernel.transform(features) @ self.transform_matrix  # [m x n_components]
+         return VS, self.LS  # [m x n_components], [n_components]
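Taken together, `fit_transform` and `update` maintain a one-shot Nystrom decomposition in the style of Fowlkes et al.: with anchor block A and accumulated cross block B, the matrix being diagonalized is S = A + A^{-1/2} B B^T A^{-1/2}, kept at size [n x n] by compressing B into the top eigenbasis of A (`Ahinv_VT`). A dense, uncompressed sketch of the same computation, assuming A is symmetric positive definite (illustration only, not part of the package):

```python
import torch

def nystrom_reference(A: torch.Tensor, B: torch.Tensor, k: int):
    # A: anchor affinities [n x n] (symmetric PSD), B: cross affinities [n x m]
    L, U = torch.linalg.eigh(A)
    Ah_inv = U @ torch.diag(L ** -0.5) @ U.mT         # A^{-1/2}, [n x n]
    S = A + Ah_inv @ B @ B.mT @ Ah_inv                # [n x n]
    LS, US = torch.linalg.eigh(S)
    LS, US = LS.flip(0)[:k], US.flip(1)[:, :k]        # top-k, descending order
    T = Ah_inv @ US * LS ** -0.5                      # the transform_matrix
    V = torch.cat([A @ T, B.mT @ T], dim=0)           # embedding, [(n + m) x k]
    return V, LS
```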
+
+
+ def solve_eig(
+     A: torch.Tensor,
+     num_eig: int,
+     eig_solver: Literal["svd_lowrank", "lobpcg", "svd", "eigh"],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Solve for the top num_eig eigenvectors and eigenvalues of a symmetric matrix, without Nystrom-like approximation.
+
+     Args:
+         A (torch.Tensor): input matrix, shape (n_samples, n_samples)
+         num_eig (int): number of eigenvectors to return
+         eig_solver (str): eigendecomposition solver, ['svd_lowrank', 'lobpcg', 'svd', 'eigh']
+
+     Returns:
+         (torch.Tensor): eigenvectors, shape (n_samples, num_eig)
+         (torch.Tensor): corresponding eigenvalues, sorted in descending order, shape (num_eig,)
+     """
+     # compute eigenvectors
+     if eig_solver == "svd_lowrank":  # default
+         # only the top q eigenvectors, fastest
+         eigen_vector, eigen_value, _ = torch.svd_lowrank(A, q=num_eig)
+     elif eig_solver == "lobpcg":
+         # only the top k eigenvectors, fast
+         eigen_value, eigen_vector = torch.lobpcg(A, k=num_eig)
+     elif eig_solver == "svd":
+         # all eigenvectors, slow
+         eigen_vector, eigen_value, _ = torch.svd(A)
+     elif eig_solver == "eigh":
+         # all eigenvectors, slow
+         eigen_value, eigen_vector = torch.linalg.eigh(A)
+     else:
+         raise ValueError(
+             "eig_solver should be 'svd_lowrank', 'lobpcg', 'svd' or 'eigh'"
+         )
+
+     # sort eigenvectors by eigenvalue and keep the top num_eig (descending order)
+     eigen_value = eigen_value.real
+     eigen_vector = eigen_vector.real
+     eigen_value, indices = torch.topk(eigen_value, k=num_eig, dim=0)
+     eigen_vector = eigen_vector[:, indices]
+
+     # correct the arbitrary sign of eigenvectors
+     eigen_vector = eigen_vector * torch.sum(eigen_vector, dim=0).sign()
+     return eigen_vector, eigen_value
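A quick smoke test for `solve_eig` on a small symmetric PSD matrix (a sketch; assumes `solve_eig` is importable from this module):

```python
import torch

X = torch.randn(100, 8)
A = X @ X.mT                                # symmetric PSD test matrix
vecs, vals = solve_eig(A, num_eig=4, eig_solver="eigh")
print(vecs.shape, vals.shape)               # torch.Size([100, 4]) torch.Size([4])
assert torch.all(vals[:-1] >= vals[1:])     # sorted in descending order
```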