nystrom-ncut 0.0.6.tar.gz → 0.0.7.tar.gz

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: nystrom_ncut
-Version: 0.0.6
+Version: 0.0.7
 Summary: Normalized Cut and Nyström Approximation
 Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
 Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "nystrom_ncut"
-version = "0.0.6"
+version = "0.0.7"
 authors = [
   { name = "Huzheng Yang", email = "huze.yann@gmail.com" },
   { name = "Wentinn Liao", email = "wentinn.liao@gmail.com" },
@@ -4,8 +4,8 @@ from .ncut_pytorch import (
 )
 from .propagation_utils import (
     affinity_from_features,
-    propagate_eigenvectors,
-    propagate_knn,
+    extrapolate_knn_with_subsampling,
+    extrapolate_knn,
     quantile_normalize,
 )
 from .visualize_utils import (
@@ -17,6 +17,5 @@ from .visualize_utils import (
     rgb_from_cosine_tsne_3d,
     rotate_rgb_cube,
     convert_to_lab_color,
-    propagate_rgb_color,
     get_mask,
 )
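
Note: 0.0.7 renames `propagate_knn` to `extrapolate_knn` and `propagate_eigenvectors` to `extrapolate_knn_with_subsampling`, and drops `propagate_rgb_color` from the public API. A minimal migration sketch for a hypothetical 0.0.6 call site, using the tensor shapes from the docstring examples further down in this diff:

    import torch
    from nystrom_ncut import extrapolate_knn

    old_features = torch.randn(3000, 100)     # anchor features
    old_eigenvectors = torch.randn(3000, 20)  # anchor outputs to propagate
    new_features = torch.randn(200, 100)      # points to extrapolate to

    # 0.0.6: propagate_knn(old_eigenvectors, new_features, old_features, knn=3)
    # 0.0.7: anchor features first, then anchor outputs, then query features.
    new_eigenvectors = extrapolate_knn(old_features, old_eigenvectors, new_features, knn=3)
    # new_eigenvectors.shape == (200, 20)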
@@ -1,10 +1,14 @@
-from typing import Any
+from typing import Any, Literal
 
 import numpy as np
 import torch
 import torch.nn.functional as Fn
 
 
+DistanceOptions = Literal["cosine", "euclidean", "rbf"]
+SampleOptions = Literal["farthest", "random"]
+
+
 def ceildiv(a: int, b: int) -> int:
     return -(-a // b)
 
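
The two `Literal` aliases centralize option strings that were previously repeated in each signature. A small illustration of what this buys (static checking only; `pick_distance` is a hypothetical function, not part of the package):

    from typing import Literal

    DistanceOptions = Literal["cosine", "euclidean", "rbf"]

    def pick_distance(distance: DistanceOptions) -> str:
        return distance

    pick_distance("cosine")     # accepted
    pick_distance("manhattan")  # rejected by mypy/pyright; still runs at runtime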
@@ -4,6 +4,10 @@ from typing import Literal, Tuple
 import torch
 import torch.nn.functional as Fn
 
+from .common import (
+    DistanceOptions,
+    SampleOptions,
+)
 from .nystrom import (
     EigSolverOptions,
     OnlineKernel,
@@ -16,9 +20,6 @@ from .propagation_utils import (
 )
 
 
-DistanceOptions = Literal["cosine", "euclidean", "rbf"]
-
-
 class LaplacianKernel(OnlineKernel):
     def __init__(
         self,
@@ -46,9 +47,10 @@ class LaplacianKernel(OnlineKernel):
             affinity_focal_gamma=self.affinity_focal_gamma,
             distance=self.distance,
         )                                               # [n x n]
+        d = features.shape[-1]
         U, L = solve_eig(
             self.A,
-            num_eig=features.shape[-1] + 1,
+            num_eig=d + 1,  # d * (d + 3) // 2 + 1,
             eig_solver=self.eig_solver,
         )                                               # [n x (d + 1)], [d + 1]
         self.Ainv = U @ torch.diag(1 / L) @ U.mT        # [n x n]
@@ -97,11 +99,10 @@ class NCUT(OnlineNystrom):
         n_components: int = 100,
         affinity_focal_gamma: float = 1.0,
         num_sample: int = 10000,
-        sample_method: Literal["farthest", "random"] = "farthest",
+        sample_method: SampleOptions = "farthest",
         distance: DistanceOptions = "cosine",
         eig_solver: EigSolverOptions = "svd_lowrank",
         normalize_features: bool = None,
-        move_output_to_cpu: bool = False,
         chunk_size: int = 8192,
     ):
         """
@@ -117,7 +118,6 @@ class NCUT(OnlineNystrom):
             eig_solver (str): eigen decompose solver, ['svd_lowrank', 'lobpcg', 'svd', 'eigh'].
             normalize_features (bool): normalize input features before computing affinity matrix,
                 default 'None' is True for cosine distance, False for euclidean distance and rbf
-            move_output_to_cpu (bool): move output to CPU, set to True if you have memory issue
             chunk_size (int): chunk size for large-scale matrix multiplication
         """
         OnlineNystrom.__init__(
@@ -127,18 +127,18 @@ class NCUT(OnlineNystrom):
             eig_solver=eig_solver,
             chunk_size=chunk_size,
         )
-        self.num_sample = num_sample
-        self.sample_method = sample_method
-        self.distance = distance
-        self.normalize_features = normalize_features
+        self.num_sample: int = num_sample
+        self.sample_method: SampleOptions = sample_method
+        self.anchor_indices: torch.Tensor = None
+        self.distance: DistanceOptions = distance
+        self.normalize_features: bool = normalize_features
         if self.normalize_features is None:
             if distance in ["cosine"]:
                 self.normalize_features = True
             if distance in ["euclidean", "rbf"]:
                 self.normalize_features = False
 
-        self.move_output_to_cpu = move_output_to_cpu
-        self.chunk_size = chunk_size
+        self.chunk_size: int = chunk_size
 
     def _fit_helper(
         self,
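
With `move_output_to_cpu` gone from the constructor and the sampled indices now kept on the instance as `anchor_indices`, a 0.0.7 usage sketch (shapes borrowed from the test script later in this diff):

    import torch
    from nystrom_ncut import NCUT

    nc = NCUT(n_components=30, num_sample=1000, sample_method="farthest", eig_solver="svd")
    X, eigs = nc.fit_transform(torch.rand(1200, 12))  # [1200 x 30], [30]
    anchors = nc.anchor_indices                       # sorted indices of the sampled anchors
    # move_output_to_cpu now lives only on extrapolate_knn, not on NCUT.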
@@ -152,16 +152,6 @@ class NCUT(OnlineNystrom):
         )
         self.num_sample = _n
 
-        # check if features dimension greater than num_eig
-        if self.eig_solver in ["svd_lowrank", "lobpcg"]:
-            assert (
-                _n >= self.n_components * 2
-            ), "number of nodes should be greater than 2*num_eig"
-        elif self.eig_solver in ["svd", "eigh"]:
-            assert (
-                _n >= self.n_components
-            ), "number of nodes should be greater than num_eig"
-
         assert self.distance in ["cosine", "euclidean", "rbf"], "distance should be 'cosine', 'euclidean', 'rbf'"
 
         if self.normalize_features:
@@ -169,20 +159,20 @@ class NCUT(OnlineNystrom):
             features = torch.nn.functional.normalize(features, dim=-1)
 
         if precomputed_sampled_indices is not None:
-            sampled_indices = precomputed_sampled_indices
+            _sampled_indices = precomputed_sampled_indices
         else:
-            sampled_indices = run_subgraph_sampling(
+            _sampled_indices = run_subgraph_sampling(
                 features,
                 self.num_sample,
                 sample_method=self.sample_method,
             )
-        sampled_indices = torch.sort(sampled_indices).values
-        sampled_features = features[sampled_indices]
+        self.anchor_indices = torch.sort(_sampled_indices).values
+        sampled_features = features[self.anchor_indices]
         OnlineNystrom.fit(self, sampled_features)
 
         _n_not_sampled = _n - len(sampled_features)
         if _n_not_sampled > 0:
-            unsampled_indices = torch.full((_n,), True, device=features.device).scatter_(0, sampled_indices, False)
+            unsampled_indices = torch.full((_n,), True, device=features.device).scatter_(0, self.anchor_indices, False)
             unsampled_features = features[unsampled_indices]
             V_unsampled, _ = OnlineNystrom.update(self, unsampled_features)
         else:
@@ -72,7 +72,7 @@ class OnlineNystrom:
         self.anchor_features = features
 
         self.kernel.fit(self.anchor_features)
-        self.inverse_approximation_dim = max(self.n_components, features.shape[-1]) + 1
+        self.inverse_approximation_dim = max(self.n_components, features.shape[-1] + 1)
         U, L = self._update_to_kernel()                 # [n x (? + 1)], [? + 1]
 
         self.transform_matrix = (U / L)[:, :self.n_components]     # [n x n_components]
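
This is a behavior change, not just a cleanup: the parenthesis moves so that the `+ 1` pads the feature dimension rather than the final `max`. A quick numeric check of the case where the two expressions differ (reading the extra column as matching the `[n x (d + 1)]` comments above is our assumption):

    n_components, d = 100, 50
    old = max(n_components, d) + 1   # 101
    new = max(n_components, d + 1)   # 100
    assert old != new                # the two agree only when d + 1 > n_components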
@@ -135,7 +135,7 @@ class OnlineNystrom:
 def solve_eig(
     A: torch.Tensor,
     num_eig: int,
-    eig_solver: Literal["svd_lowrank", "lobpcg", "svd", "eigh"],
+    eig_solver: EigSolverOptions,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """PyTorch implementation of Eigensolver cut without Nystrom-like approximation.
@@ -3,9 +3,14 @@ from typing import Literal
 
 import numpy as np
 import torch
-import torch.nn.functional as F
+import torch.nn.functional as Fn
 
-from .common import ceildiv, lazy_normalize
+from .common import (
+    DistanceOptions,
+    SampleOptions,
+    ceildiv,
+    lazy_normalize,
+)
 
 
 @torch.no_grad()
@@ -13,7 +18,7 @@ def run_subgraph_sampling(
     features: torch.Tensor,
     num_sample: int,
     max_draw: int = 1000000,
-    sample_method: Literal["farthest", "random"] = "farthest",
+    sample_method: SampleOptions = "farthest",
 ):
     if num_sample >= features.shape[0]:
         # if too many samples, use all samples and bypass Nystrom-like approximation
@@ -74,7 +79,7 @@ def farthest_point_sampling(
 def distance_from_features(
     features: torch.Tensor,
     features_B: torch.Tensor,
-    distance: Literal["cosine", "euclidean", "rbf"],
+    distance: DistanceOptions,
 ):
     """Compute affinity matrix from input features.
     Args:
@@ -103,7 +108,7 @@ def affinity_from_features(
     features: torch.Tensor,
     features_B: torch.Tensor = None,
     affinity_focal_gamma: float = 1.0,
-    distance: Literal["cosine", "euclidean", "rbf"] = "cosine",
+    distance: DistanceOptions = "cosine",
 ):
     """Compute affinity matrix from input features.
 
@@ -131,23 +136,23 @@ def affinity_from_features(
     return A
 
 
-def propagate_knn(
-    subgraph_output: torch.Tensor,
-    inp_features: torch.Tensor,
-    subgraph_features: torch.Tensor,
-    knn: int = 10,
-    distance: Literal["cosine", "euclidean", "rbf"] = "cosine",
+def extrapolate_knn(
+    anchor_features: torch.Tensor,          # [n x d]
+    anchor_output: torch.Tensor,            # [n x d']
+    extrapolation_features: torch.Tensor,   # [m x d]
+    knn: int = 10,                          # k
+    distance: DistanceOptions = "cosine",
     affinity_focal_gamma: float = 1.0,
     chunk_size: int = 8192,
     device: str = None,
-    move_output_to_cpu: bool = False,
-):
+    move_output_to_cpu: bool = False
+) -> torch.Tensor:                          # [m x d']
     """A generic function to propagate new nodes using KNN.
 
     Args:
-        subgraph_output (torch.Tensor): output from subgraph, shape (num_sample, D)
-        inp_features (torch.Tensor): features from existing nodes, shape (new_num_samples, n_features)
-        subgraph_features (torch.Tensor): features from subgraph, shape (num_sample, n_features)
+        anchor_features (torch.Tensor): features from subgraph, shape (num_sample, n_features)
+        anchor_output (torch.Tensor): output from subgraph, shape (num_sample, D)
+        extrapolation_features (torch.Tensor): features from existing nodes, shape (new_num_samples, n_features)
         knn (int): number of KNN to propagate eigenvectors
         distance (str): distance metric, 'cosine' (default) or 'euclidean', 'rbf'
         chunk_size (int): chunk size for matrix multiplication
@@ -159,121 +164,77 @@ def propagate_knn(
         >>> old_eigenvectors = torch.randn(3000, 20)
         >>> old_features = torch.randn(3000, 100)
         >>> new_features = torch.randn(200, 100)
-        >>> new_eigenvectors = propagate_knn(old_eigenvectors, new_features, old_features, knn=3)
+        >>> new_eigenvectors = extrapolate_knn(old_features, old_eigenvectors, new_features, knn=3)
         >>> # new_eigenvectors.shape = (200, 20)
 
     """
-    device = subgraph_output.device if device is None else device
-
-    if knn == 1:
-        return propagate_nearest(
-            subgraph_output,
-            inp_features,
-            subgraph_features,
-            chunk_size=chunk_size,
-            device=device,
-            move_output_to_cpu=move_output_to_cpu,
-        )
+    device = anchor_output.device if device is None else device
 
     # used in nystrom_ncut
     # propagate eigen_vector from subgraph to full graph
-    subgraph_output = subgraph_output.to(device)
+    anchor_output = anchor_output.to(device)
 
-    n_chunks = ceildiv(inp_features.shape[0], chunk_size)
+    n_chunks = ceildiv(extrapolation_features.shape[0], chunk_size)
     V_list = []
-    for _v in torch.chunk(inp_features, n_chunks, dim=0):
-        _v = _v.to(device)
-
-        # _A = affinity_from_features(subgraph_features, _v, affinity_focal_gamma, distance).mT
-        # if knn is not None:
-        #     mask = torch.full_like(_A, True, dtype=torch.bool)
-        #     mask[torch.arange(len(_v))[:, None], _A.topk(knn, dim=-1, largest=True).indices] = False
-        #     _A[mask] = 0.0
-        # _A = F.normalize(_A, p=1, dim=-1)
-
-        if distance == 'cosine':
-            _A = _v @ subgraph_features.T
-        elif distance == 'euclidean':
-            _A = - torch.cdist(_v, subgraph_features, p=2)
-        elif distance == 'rbf':
-            _A = - torch.cdist(_v, subgraph_features, p=2) ** 2
+    for _v in torch.chunk(extrapolation_features, n_chunks, dim=0):
+        _v = _v.to(device)                                                                      # [_m x d]
+        _A = affinity_from_features(anchor_features, _v, affinity_focal_gamma, distance).mT    # [_m x n]
+        if knn is not None:
+            _A, indices = _A.topk(k=knn, dim=-1, largest=True)                                  # [_m x k], [_m x k]
+            _anchor_output = anchor_output[indices]                                             # [_m x k x d]
         else:
-            raise ValueError("distance should be 'cosine' or 'euclidean', 'rbf'")
-
-        # keep topk KNN for each row
-        topk_sim, topk_idx = _A.topk(knn, dim=-1, largest=True)
-        row_id = torch.arange(topk_idx.shape[0], device=_A.device)[:, None].expand(
-            -1, topk_idx.shape[1]
-        )
-        _A = torch.sparse_coo_tensor(
-            torch.stack([row_id, topk_idx], dim=-1).reshape(-1, 2).T,
-            topk_sim.reshape(-1),
-            size=(_A.shape[0], _A.shape[1]),
-            device=_A.device,
-        )
-        _A = _A.to_dense().to(dtype=subgraph_output.dtype)
-        _D = _A.sum(-1)
-        _A /= _D[:, None]
-
-        _V = _A @ subgraph_output
-        if move_output_to_cpu:
-            _V = _V.cpu()
-        V_list.append(_V)
-
-    subgraph_output = torch.cat(V_list, dim=0)
-    return subgraph_output
-
-
-def propagate_nearest(
-    subgraph_output: torch.Tensor,
-    inp_features: torch.Tensor,
-    subgraph_features: torch.Tensor,
-    distance: Literal["cosine", "euclidean", "rbf"] = "cosine",
-    chunk_size: int = 8192,
-    device: str = None,
-    move_output_to_cpu: bool = False,
-):
-    device = subgraph_output.device if device is None else device
-    if distance == 'cosine':
-        inp_features = lazy_normalize(inp_features, dim=-1)
-        subgraph_features = lazy_normalize(subgraph_features, dim=-1)
-
-    # used in nystrom_tsne, equivalent to propagate_by_knn with knn=1
-    # propagate tSNE from subgraph to full graph
-    V_list = []
-    subgraph_features = subgraph_features.to(device)
-    for i in range(0, inp_features.shape[0], chunk_size):
-        end = min(i + chunk_size, inp_features.shape[0])
-        _v = inp_features[i:end].to(device)
-        _A = -distance_from_features(subgraph_features, _v, distance).mT
-
-        # keep top1 for each row
-        top_idx = _A.argmax(dim=-1).cpu()
-        _V = subgraph_output[top_idx]
+            _anchor_output = anchor_output[None]                                                # [1 x n x d]
+        _A = Fn.normalize(_A, p=1, dim=-1)
+
+        # if distance == 'cosine':
+        #     _A = _v @ subgraph_features.T
+        # elif distance == 'euclidean':
+        #     _A = - torch.cdist(_v, subgraph_features, p=2)
+        # elif distance == 'rbf':
+        #     _A = - torch.cdist(_v, subgraph_features, p=2) ** 2
+        # else:
+        #     raise ValueError("distance should be 'cosine' or 'euclidean', 'rbf'")
+        #
+        # # keep topk KNN for each row
+        # topk_sim, topk_idx = _A.topk(knn, dim=-1, largest=True)
+        # row_id = torch.arange(topk_idx.shape[0], device=_A.device)[:, None].expand(
+        #     -1, topk_idx.shape[1]
+        # )
+        # _A = torch.sparse_coo_tensor(
+        #     torch.stack([row_id, topk_idx], dim=-1).reshape(-1, 2).T,
+        #     topk_sim.reshape(-1),
+        #     size=(_A.shape[0], _A.shape[1]),
+        #     device=_A.device,
+        # )
+        # _A = _A.to_dense().to(dtype=subgraph_output.dtype)
+        # _D = _A.sum(-1)
+        # _A /= _D[:, None]
+
+        _V = (_A[:, None, :] @ _anchor_output).squeeze(1)
         if move_output_to_cpu:
             _V = _V.cpu()
         V_list.append(_V)
 
-    subgraph_output = torch.cat(V_list, dim=0)
-    return subgraph_output
+    anchor_output = torch.cat(V_list, dim=0)
+    return anchor_output
 
 
 # wrapper functions for adding new nodes to existing graph
-def propagate_eigenvectors(
-    eigenvectors: torch.Tensor,
-    features: torch.Tensor,
-    new_features: torch.Tensor,
+def extrapolate_knn_with_subsampling(
+    full_features: torch.Tensor,
+    full_output: torch.Tensor,
+    extrapolation_features: torch.Tensor,
     knn: int,
     num_sample: int,
-    sample_method: Literal["farthest", "random"],
+    sample_method: SampleOptions,
     chunk_size: int,
-    device: str,
+    device: str
 ):
     """Propagate eigenvectors to new nodes using KNN. Note: this is equivalent to the class API `NCUT.transform(new_features)`, except that the sampling is re-done in this function.
     Args:
-        eigenvectors (torch.Tensor): eigenvectors from existing nodes, shape (num_sample, num_eig)
-        features (torch.Tensor): features from existing nodes, shape (n_samples, n_features)
-        new_features (torch.Tensor): features from new nodes, shape (n_new_samples, n_features)
+        full_output (torch.Tensor): eigenvectors from existing nodes, shape (num_sample, num_eig)
+        full_features (torch.Tensor): features from existing nodes, shape (n_samples, n_features)
+        extrapolation_features (torch.Tensor): features from new nodes, shape (n_new_samples, n_features)
         knn (int): number of KNN to propagate eigenvectors, default 3
         num_sample (int): number of samples for subgraph sampling, default 50000
         sample_method (str): sample method, 'farthest' (default) or 'random'
@@ -286,31 +247,31 @@ def propagate_eigenvectors(
         >>> old_eigenvectors = torch.randn(3000, 20)
         >>> old_features = torch.randn(3000, 100)
         >>> new_features = torch.randn(200, 100)
-        >>> new_eigenvectors = propagate_eigenvectors(old_eigenvectors, new_features, old_features, knn=3)
+        >>> new_eigenvectors = extrapolate_knn_with_subsampling(old_features, old_eigenvectors, new_features, knn=3, num_sample=50000, sample_method="farthest", chunk_size=8192, device=None)
         >>> # new_eigenvectors.shape = (200, 20)
     """
 
-    device = eigenvectors.device if device is None else device
+    device = full_output.device if device is None else device
 
     # sample subgraph
-    subgraph_indices = run_subgraph_sampling(
-        features,
+    anchor_indices = run_subgraph_sampling(
+        full_features,
         num_sample,
         sample_method=sample_method,
     )
 
-    subgraph_eigenvectors = eigenvectors[subgraph_indices].to(device)
-    subgraph_features = features[subgraph_indices].to(device)
-    new_features = new_features.to(device)
+    anchor_output = full_output[anchor_indices].to(device)
+    anchor_features = full_features[anchor_indices].to(device)
+    extrapolation_features = extrapolation_features.to(device)
 
     # propagate eigenvectors from subgraph to new nodes
-    new_eigenvectors = propagate_knn(
-        subgraph_eigenvectors,
-        new_features,
-        subgraph_features,
+    new_eigenvectors = extrapolate_knn(
+        anchor_features,
+        anchor_output,
+        extrapolation_features,
        knn=knn,
        chunk_size=chunk_size,
-        device=device,
+        device=device
     )
     return new_eigenvectors
 
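
For reference, the renamed wrapper keeps the old propagate-eigenvectors workflow. A hedged usage sketch matching the new signature (the `num_sample` and other values below are illustrative, not defaults from this diff):

    import torch
    from nystrom_ncut import extrapolate_knn_with_subsampling

    full_features = torch.randn(3000, 100)
    full_output = torch.randn(3000, 20)   # e.g. eigenvectors on the full graph
    new_features = torch.randn(200, 100)

    new_output = extrapolate_knn_with_subsampling(
        full_features, full_output, new_features,
        knn=3, num_sample=1000, sample_method="farthest",
        chunk_size=8192, device=None,
    )
    # new_output.shape == (200, 20)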
@@ -6,11 +6,14 @@ import torch
 import torch.nn.functional as F
 from sklearn.base import BaseEstimator
 
-from .common import lazy_normalize
+from .common import (
+    DistanceOptions,
+    lazy_normalize,
+)
 from .propagation_utils import (
     run_subgraph_sampling,
-    propagate_knn,
-    propagate_eigenvectors,
+    extrapolate_knn,
+    extrapolate_knn_with_subsampling,
     quantile_min_max,
     quantile_normalize
 )
@@ -31,14 +34,29 @@ def _rgb_with_dimensionality_reduction(
     reduction_dim: int,
     reduction_kwargs: Dict[str, Any],
     transform_func: Callable[[torch.Tensor], torch.Tensor] = _identity,
+    pre_smooth: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+    if pre_smooth:
+        _subgraph_indices = run_subgraph_sampling(
+            features,
+            num_sample,
+            sample_method="farthest",
+        )
+        features = extrapolate_knn(
+            features[_subgraph_indices],
+            features[_subgraph_indices],
+            features,
+            distance="cosine",
+        )
+
     subgraph_indices = run_subgraph_sampling(
         features,
         num_sample,
         sample_method="farthest",
     )
 
-    _inp = features[subgraph_indices].cpu().numpy()
+    _inp = features[subgraph_indices].numpy(force=True)
     _subgraph_embed = reduction(
         n_components=reduction_dim,
         metric=metric,
@@ -47,14 +65,14 @@ def _rgb_with_dimensionality_reduction(
     ).fit_transform(_inp)
 
     _subgraph_embed = torch.tensor(_subgraph_embed, dtype=torch.float32)
-    X_nd = transform_func(propagate_knn(
+    X_nd = transform_func(extrapolate_knn(
+        features[subgraph_indices],
         _subgraph_embed,
         features,
-        features[subgraph_indices],
-        distance=metric,
         knn=knn,
+        distance=metric,
         device=device,
-        move_output_to_cpu=True,
+        move_output_to_cpu=True
     ))
     rgb = rgb_func(X_nd, q)
     return X_nd, rgb
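
The new `pre_smooth` step (on by default) replaces every feature row with a KNN-weighted average over farthest-point anchors before the anchors are embedded, which should denoise the t-SNE/UMAP input. A standalone sketch of that step; the import path for `run_subgraph_sampling` is an assumption, since only the package-internal relative import appears in this diff:

    import torch
    from nystrom_ncut import extrapolate_knn
    from nystrom_ncut.propagation_utils import run_subgraph_sampling  # path assumed

    features = torch.randn(5000, 64)
    idx = run_subgraph_sampling(features, num_sample=1000, sample_method="farthest")
    # The anchors serve as both the KNN keys and the values being averaged.
    features = extrapolate_knn(features[idx], features[idx], features, distance="cosine")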
@@ -413,48 +431,6 @@ def rgb_from_2d_colormap(X_2d, q=0.95):
     return rgb
 
 
-def propagate_rgb_color(
-    rgb: torch.Tensor,
-    eigenvectors: torch.Tensor,
-    new_eigenvectors: torch.Tensor,
-    knn: int = 10,
-    num_sample: int = 1000,
-    sample_method: Literal["farthest", "random"] = "farthest",
-    chunk_size: int = 8192,
-    device: str = None,
-):
-    """Propagate RGB color to new nodes using KNN.
-    Args:
-        rgb (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
-        features (torch.Tensor): features from existing nodes, shape (n_samples, n_features)
-        new_features (torch.Tensor): features from new nodes, shape (n_new_samples, n_features)
-        knn (int): number of KNN to propagate RGB color, default 1
-        num_sample (int): number of samples for subgraph sampling, default 50000
-        sample_method (str): sample method, 'farthest' (default) or 'random'
-        chunk_size (int): chunk size for matrix multiplication, default 8192
-        device (str): device to use for computation, if None, will not change device
-    Returns:
-        torch.Tensor: propagated RGB color for each data sample, shape (n_new_samples, 3)
-
-    Examples:
-        >>> old_rgb = torch.randn(3000, 3)
-        >>> old_eigenvectors = torch.randn(3000, 20)
-        >>> new_eigenvectors = torch.randn(200, 20)
-        >>> new_rgb = propagate_rgb_color(old_rgb, new_eigenvectors, old_eigenvectors)
-        >>> # new_eigenvectors.shape = (200, 3)
-    """
-    return propagate_eigenvectors(
-        eigenvectors=rgb,
-        features=eigenvectors,
-        new_features=new_eigenvectors,
-        knn=knn,
-        num_sample=num_sample,
-        sample_method=sample_method,
-        chunk_size=chunk_size,
-        device=device,
-    )
-
-
 # application: get segmentation mask from a reference eigenvector (point prompt)
 def _transform_heatmap(heatmap, gamma=1.0):
     """Transform the heatmap using gamma, normalize and min-max normalization.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: nystrom_ncut
-Version: 0.0.6
+Version: 0.0.7
 Summary: Normalized Cut and Nyström Approximation
 Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
 Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
@@ -0,0 +1,190 @@
+import numpy as np
+import torch
+import torch.nn.functional as Fn
+from matplotlib import pyplot as plt
+
+from src.nystrom_ncut.ncut_pytorch import NCUT, axis_align, affinity_from_features
+from ncut_pytorch import NCUT as OldNCUT
+# from ncut_pytorch.src import rgb_from_umap_sphere
+# from ncut_pytorch.src.new_ncut_pytorch import NewNCUT
+
+# from ncut_pytorch.ncut_pytorch.backbone_text import load_text_model
+
+
+if __name__ == "__main__":
+    # torch.manual_seed(1212)
+    # M = torch.randn((7, 3))
+    # W = torch.nn.functional.cosine_similarity(M[:, None], M[None, :], dim=-1)
+    # A = torch.exp(W - 1)
+    # D_s2 = torch.sum(A, dim=-1, keepdim=True) ** -0.5
+    # # print(A)
+    # print(A * D_s2 * D_s2.mT)
+    #
+    # ncut = NCUT(num_eig=7, knn=1, eig_solver="svd")
+    # V, L = ncut.fit_transform(M)
+    # print(V @ torch.diag(L) @ V.mT)
+    # raise Exception()
+
+    # print(load_text_model("meta-llama/Meta-Llama-3.1-8B").cuda())
+    # print(AutoModelForCausalLM.from_pretrained(
+    #     "meta-llama/Meta-Llama-3.1-8B",
+    #     token="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",
+    # ))
+    # # print(transformers.pipeline(
+    # #     "text-generation",
+    # #     model="meta-llama/Meta-Llama-3.1-8B",
+    # #     model_kwargs={"torch_dtype": torch.bfloat16},
+    # #     token="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",
+    # #     device="cpu",
+    # # ))
+    # raise Exception(
+
+    torch.set_printoptions(precision=8, sci_mode=False, linewidth=400)
+    torch.set_default_dtype(torch.float64)
+    torch.manual_seed(1212)
+    np.random.seed(1212)
+
+    n = 120
+    num_sample = 100
+
+    M = torch.rand((n, 12))
+    distance = "rbf"
+
+    A = affinity_from_features(M, distance=distance)
+    R = torch.diag(torch.sum(A, dim=-1) ** -0.5)
+    L = R @ A @ R
+
+    # C = L[num_sample:, num_sample:]
+    #
+    # _A = L[:num_sample, :num_sample]
+    # _B = L[:num_sample, num_sample:]
+    # extrapolated_C = _B.mT @ torch.inverse(_A) @ _B
+    #
+    # RE = torch.abs(extrapolated_C / C - 1)
+    # print(torch.max(RE).item(), torch.mean(RE).item(), torch.min(RE).item())
+
+    n_components = 30  # num_sample
+    eig_solver = "svd"
+
+    def rel_error(X, eigs):
+        _L = X @ torch.diag(eigs) @ X.mT
+        return torch.abs(_L / L - 1)
+
+    def print_re(re):
+        print(f"max: {re.max().item()}, mean: {re.mean().item()}, min: {re.min().item()}")
+
+    nc0 = NCUT(n_components=n_components, num_sample=num_sample, distance=distance, eig_solver=eig_solver)
+    X0, eigs0 = nc0.fit_transform(M)
+
+    re0 = rel_error(X0, eigs0)
+    print_re(re0)
+
+    plt.imshow(re0)
+    plt.colorbar()
+    plt.show()
+
+    plt.scatter(torch.arange(n), torch.linalg.norm(X0, dim=-1))
+    plt.show()
+    raise Exception()
+
+
+    #
+    # # plt.scatter(torch.arange(n), torch.linalg.norm(X0, dim=-1))
+    # # plt.show()
+    # # raise Exception()
+    #
+    # def align_to(X, eigs):
+    #     sign = torch.sign(torch.sum(X0 * X, dim=0))
+    #     return X * sign, eigs
+    #
+    # Xs = []
+    # n_trials = 20
+    # sum_X, sum_eigs = 0.0, 0.0
+    # for _ in range(n_trials):
+    #     nc = NCUT(n_components=n_components, num_sample=num_sample, distance=distance, eig_solver=eig_solver)
+    #     X, eigs = align_to(*nc.fit_transform(M))
+    #     Xs.append(X)
+    #
+    #     re = rel_error(X, eigs)
+    #     print(f"max: {re.max().item()}, mean: {re.mean().item()}, min: {re.min().item()}")
+    #
+    #     # print(X[:3, :10])
+    #     # print(eigs[:10])
+    #
+    #     sum_X = sum_X + X
+    #     sum_eigs = sum_eigs + eigs
+    #
+    # # print(torch.diag(Xs[0].mT @ Xs[1]))
+    # # raise Exception()
+    #
+    # print("=" * 120)
+    # mean_X, mean_eigs = sum_X / n_trials, sum_eigs / n_trials
+    # mean_re = rel_error(mean_X, mean_eigs)
+    # print(f"max: {mean_re.max().item()}, mean: {mean_re.mean().item()}, min: {mean_re.min().item()}")
+    #
+    # raise Exception()
+
+
+
+    ncs = [
+        NCUT(n_components=n_components, num_sample=n, distance=distance, eig_solver=eig_solver),
+        NCUT(n_components=n_components, num_sample=num_sample, distance=distance, eig_solver=eig_solver),
+        # OldNCUT(num_eig=n_components, num_sample=num_sample, knn=10, distance=distance, eig_solver=eig_solver, make_orthogonal=True),
+    ]
+
+    for NC in ncs:
+        torch.manual_seed(1212)
+        np.random.seed(1212)
+        X, eigs = NC.fit_transform(M)
+
+        RE = rel_error(X, eigs)
+        print(f"max: {RE.max().item()}, mean: {RE.mean().item()}, min: {RE.min().item()}")
+
+    # torch.manual_seed(1212)
+    # np.random.seed(1212)
+    #
+    # aX, R = axis_align(X)
+    # print(aX[:3])
+    # print(R)
+    # print(R @ R.mT)
+
+
+
+
+    # import time
+    # n_trials = 10
+    #
+    # with torch.no_grad():
+    #     start_t = time.perf_counter()
+    #     for _ in range(n_trials):
+    #         X, eigs = NC.fit_transform(M)
+    #     end_t = time.perf_counter()
+    #     print(X.min().item(), X.max().item(), eigs)
+    #     print(f"{1e3 * (end_t - start_t) / n_trials}ms")
+    #
+    #     start_t = time.perf_counter()
+    #     for _ in range(n_trials):
+    #         nX, neigs = nNC.fit_transform(M)
+    #     end_t = time.perf_counter()
+    #     print(nX.min().item(), nX.max().item(), neigs)
+    #     print(f"{1e3 * (end_t - start_t) / n_trials}ms")
+    # raise Exception()
+
+    # assert torch.all(torch.isclose(X, torch.Tensor([
+    #     [0.320216, 0.144101, -0.110744, -0.560543, -0.007982],
+    #     [0.297634, 0.662867, 0.146107, 0.277893, 0.553959],
+    #     [0.324994, -0.057295, 0.052916, 0.391666, -0.460911],
+    #     [0.301703, -0.460709, 0.528563, 0.222525, 0.325546],
+    #     [0.316614, 0.043475, -0.526899, 0.100665, -0.030259],
+    #     [0.325425, -0.127884, 0.294540, -0.012173, -0.303528],
+    #     [0.318136, -0.288952, -0.065148, -0.470192, 0.244805],
+    #     [0.309522, -0.352693, -0.473237, 0.234057, 0.276185],
+    #     [0.320464, 0.229301, 0.281134, -0.308938, -0.169746],
+    #     [0.326147, 0.213536, -0.112246, 0.155114, -0.341439]
+    # ]), atol=1e-6)), "Failed assertion"
+
+    # torch.manual_seed(1212)
+    # np.random.seed(1212)
+    # X_2d, rgb = rgb_from_umap_sphere(X)
+    # # X_3d, rgb = rgb_from_cosine_tsne_3d(X)
+    # print(rgb)
@@ -1,112 +0,0 @@
-import numpy as np
-import torch
-import torch.nn.functional as Fn
-
-from src.nystrom_ncut.ncut_pytorch import NCUT, axis_align
-# from ncut_pytorch.src import rgb_from_umap_sphere
-# from ncut_pytorch.src.new_ncut_pytorch import NewNCUT
-
-# from ncut_pytorch.ncut_pytorch.backbone_text import load_text_model
-
-
-if __name__ == "__main__":
-    # torch.manual_seed(1212)
-    # M = torch.randn((7, 3))
-    # W = torch.nn.functional.cosine_similarity(M[:, None], M[None, :], dim=-1)
-    # A = torch.exp(W - 1)
-    # D_s2 = torch.sum(A, dim=-1, keepdim=True) ** -0.5
-    # # print(A)
-    # print(A * D_s2 * D_s2.mT)
-    #
-    # ncut = NCUT(num_eig=7, knn=1, eig_solver="svd")
-    # V, L = ncut.fit_transform(M)
-    # print(V @ torch.diag(L) @ V.mT)
-    # raise Exception()
-
-    # print(load_text_model("meta-llama/Meta-Llama-3.1-8B").cuda())
-    # print(AutoModelForCausalLM.from_pretrained(
-    #     "meta-llama/Meta-Llama-3.1-8B",
-    #     token="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",
-    # ))
-    # # print(transformers.pipeline(
-    # #     "text-generation",
-    # #     model="meta-llama/Meta-Llama-3.1-8B",
-    # #     model_kwargs={"torch_dtype": torch.bfloat16},
-    # #     token="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",
-    # #     device="cpu",
-    # # ))
-    # raise Exception(
-
-    torch.set_printoptions(precision=8, sci_mode=False, linewidth=400)
-    torch.set_default_dtype(torch.float32)
-    torch.manual_seed(1212)
-    np.random.seed(1212)
-
-    M = torch.rand((1200, 12))
-    NC = NCUT(n_components=30, num_sample=1000, sample_method="farthest", eig_solver="svd")
-
-    torch.manual_seed(1212)
-    np.random.seed(1212)
-    X, eigs = NC.fit_transform(M)
-    print(eigs)
-    # print(X.mT @ X)
-
-    normalized_M = Fn.normalize(M, p=2, dim=-1)
-    A = torch.exp(-(1 - normalized_M @ normalized_M.mT))
-    R = torch.diag(torch.sum(A, dim=-1) ** -0.5)
-    L = R @ A @ R
-    # print(L)
-    # print(X @ torch.diag(eigs) @ X.mT)
-    # print(L)
-    RE = torch.abs(X @ torch.diag(eigs) @ X.mT / L - 1)
-    print(RE.max().item(), RE.mean().item())
-
-    # torch.manual_seed(1212)
-    # np.random.seed(1212)
-    #
-    # aX, R = axis_align(X)
-    # print(aX[:3])
-    # print(R)
-    # print(R @ R.mT)
-    raise Exception()
-
-
-
-
-    # import time
-    # n_trials = 10
-    #
-    # with torch.no_grad():
-    #     start_t = time.perf_counter()
-    #     for _ in range(n_trials):
-    #         X, eigs = NC.fit_transform(M)
-    #     end_t = time.perf_counter()
-    #     print(X.min().item(), X.max().item(), eigs)
-    #     print(f"{1e3 * (end_t - start_t) / n_trials}ms")
-    #
-    #     start_t = time.perf_counter()
-    #     for _ in range(n_trials):
-    #         nX, neigs = nNC.fit_transform(M)
-    #     end_t = time.perf_counter()
-    #     print(nX.min().item(), nX.max().item(), neigs)
-    #     print(f"{1e3 * (end_t - start_t) / n_trials}ms")
-    # raise Exception()
-
-    # assert torch.all(torch.isclose(X, torch.Tensor([
-    #     [0.320216, 0.144101, -0.110744, -0.560543, -0.007982],
-    #     [0.297634, 0.662867, 0.146107, 0.277893, 0.553959],
-    #     [0.324994, -0.057295, 0.052916, 0.391666, -0.460911],
-    #     [0.301703, -0.460709, 0.528563, 0.222525, 0.325546],
-    #     [0.316614, 0.043475, -0.526899, 0.100665, -0.030259],
-    #     [0.325425, -0.127884, 0.294540, -0.012173, -0.303528],
-    #     [0.318136, -0.288952, -0.065148, -0.470192, 0.244805],
-    #     [0.309522, -0.352693, -0.473237, 0.234057, 0.276185],
-    #     [0.320464, 0.229301, 0.281134, -0.308938, -0.169746],
-    #     [0.326147, 0.213536, -0.112246, 0.155114, -0.341439]
-    # ]), atol=1e-6)), "Failed assertion"
-
-    torch.manual_seed(1212)
-    np.random.seed(1212)
-    X_2d, rgb = rgb_from_umap_sphere(X)
-    # X_3d, rgb = rgb_from_cosine_tsne_3d(X)
-    print(rgb)