nystrom-ncut 0.0.5.tar.gz → 0.0.7.tar.gz

@@ -1,6 +1,6 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.2
  Name: nystrom_ncut
- Version: 0.0.5
+ Version: 0.0.7
  Summary: Normalized Cut and Nyström Approximation
  Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
  Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
  [project]
  name = "nystrom_ncut"
- version = "0.0.5"
+ version = "0.0.7"
  authors = [
      { name = "Huzheng Yang", email = "huze.yann@gmail.com" },
      { name = "Wentinn Liao", email = "wentinn.liao@gmail.com" },
@@ -4,8 +4,8 @@ from .ncut_pytorch import (
  )
  from .propagation_utils import (
      affinity_from_features,
-     propagate_eigenvectors,
-     propagate_knn,
+     extrapolate_knn_with_subsampling,
+     extrapolate_knn,
      quantile_normalize,
  )
  from .visualize_utils import (
@@ -17,6 +17,5 @@ from .visualize_utils import (
      rgb_from_cosine_tsne_3d,
      rotate_rgb_cube,
      convert_to_lab_color,
-     propagate_rgb_color,
      get_mask,
  )
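
Note: this release renames the propagation API. `propagate_knn` becomes `extrapolate_knn`, `propagate_eigenvectors` becomes `extrapolate_knn_with_subsampling`, and `propagate_rgb_color` is dropped from the package root. A minimal migration sketch, assuming the root-level re-exports shown in the hunk above:

    # 0.0.5
    from nystrom_ncut import propagate_knn, propagate_eigenvectors
    # 0.0.7
    from nystrom_ncut import extrapolate_knn, extrapolate_knn_with_subsampling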
@@ -1,10 +1,14 @@
- from typing import Any
+ from typing import Any, Literal
 
  import numpy as np
  import torch
  import torch.nn.functional as Fn
 
 
+ DistanceOptions = Literal["cosine", "euclidean", "rbf"]
+ SampleOptions = Literal["farthest", "random"]
+
+
  def ceildiv(a: int, b: int) -> int:
      return -(-a // b)
 
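
Note: `DistanceOptions` and `SampleOptions` are plain `typing.Literal` aliases, so static checkers can reject invalid metric or sampler names while runtime behavior is unchanged. An illustrative sketch (the `pairwise` function is hypothetical, not part of the package):

    from typing import Literal

    DistanceOptions = Literal["cosine", "euclidean", "rbf"]

    def pairwise(distance: DistanceOptions) -> None:
        ...

    pairwise("cosine")     # accepted by mypy/pyright
    pairwise("manhattan")  # flagged by mypy/pyright, though it still runs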
@@ -4,6 +4,10 @@ from typing import Literal, Tuple
  import torch
  import torch.nn.functional as Fn
 
+ from .common import (
+     DistanceOptions,
+     SampleOptions,
+ )
  from .nystrom import (
      EigSolverOptions,
      OnlineKernel,
@@ -16,9 +20,6 @@ from .propagation_utils import (
  )
 
 
- DistanceOptions = Literal["cosine", "euclidean", "rbf"]
-
-
  class LaplacianKernel(OnlineKernel):
      def __init__(
          self,
@@ -46,9 +47,10 @@ class LaplacianKernel(OnlineKernel):
              affinity_focal_gamma=self.affinity_focal_gamma,
              distance=self.distance,
          )  # [n x n]
+         d = features.shape[-1]
          U, L = solve_eig(
              self.A,
-             num_eig=features.shape[-1] + 1,
+             num_eig=d + 1,  # d * (d + 3) // 2 + 1,
              eig_solver=self.eig_solver,
          )  # [n x (d + 1)], [d + 1]
          self.Ainv = U @ torch.diag(1 / L) @ U.mT  # [n x n]
@@ -97,11 +99,10 @@ class NCUT(OnlineNystrom):
          n_components: int = 100,
          affinity_focal_gamma: float = 1.0,
          num_sample: int = 10000,
-         sample_method: Literal["farthest", "random"] = "farthest",
+         sample_method: SampleOptions = "farthest",
          distance: DistanceOptions = "cosine",
          eig_solver: EigSolverOptions = "svd_lowrank",
          normalize_features: bool = None,
-         move_output_to_cpu: bool = False,
          chunk_size: int = 8192,
      ):
          """
@@ -117,7 +118,6 @@ class NCUT(OnlineNystrom):
          eig_solver (str): eigen decompose solver, ['svd_lowrank', 'lobpcg', 'svd', 'eigh'].
          normalize_features (bool): normalize input features before computing affinity matrix,
              default 'None' is True for cosine distance, False for euclidean distance and rbf
-         move_output_to_cpu (bool): move output to CPU, set to True if you have memory issue
          chunk_size (int): chunk size for large-scale matrix multiplication
          """
          OnlineNystrom.__init__(
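
Note: dropping `move_output_to_cpu` from the `NCUT` constructor is a breaking change; the parameter survives only on `extrapolate_knn` (see the propagation hunks below). Code written against 0.0.5 that passed it will raise a `TypeError` on 0.0.7. A hedged migration sketch:

    # 0.0.5
    nc = NCUT(n_components=20, move_output_to_cpu=True)
    # 0.0.7: drop the argument
    nc = NCUT(n_components=20)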
@@ -127,18 +127,18 @@ class NCUT(OnlineNystrom):
              eig_solver=eig_solver,
              chunk_size=chunk_size,
          )
-         self.num_sample = num_sample
-         self.sample_method = sample_method
-         self.distance = distance
-         self.normalize_features = normalize_features
+         self.num_sample: int = num_sample
+         self.sample_method: SampleOptions = sample_method
+         self.anchor_indices: torch.Tensor = None
+         self.distance: DistanceOptions = distance
+         self.normalize_features: bool = normalize_features
          if self.normalize_features is None:
              if distance in ["cosine"]:
                  self.normalize_features = True
              if distance in ["euclidean", "rbf"]:
                  self.normalize_features = False
 
-         self.move_output_to_cpu = move_output_to_cpu
-         self.chunk_size = chunk_size
+         self.chunk_size: int = chunk_size
 
      def _fit_helper(
          self,
@@ -152,16 +152,6 @@ class NCUT(OnlineNystrom):
          )
          self.num_sample = _n
 
-         # check if features dimension greater than num_eig
-         if self.eig_solver in ["svd_lowrank", "lobpcg"]:
-             assert (
-                 _n >= self.n_components * 2
-             ), "number of nodes should be greater than 2*num_eig"
-         elif self.eig_solver in ["svd", "eigh"]:
-             assert (
-                 _n >= self.n_components
-             ), "number of nodes should be greater than num_eig"
-
          assert self.distance in ["cosine", "euclidean", "rbf"], "distance should be 'cosine', 'euclidean', 'rbf'"
 
          if self.normalize_features:
@@ -169,20 +159,20 @@ class NCUT(OnlineNystrom):
              features = torch.nn.functional.normalize(features, dim=-1)
 
          if precomputed_sampled_indices is not None:
-             sampled_indices = precomputed_sampled_indices
+             _sampled_indices = precomputed_sampled_indices
          else:
-             sampled_indices = run_subgraph_sampling(
+             _sampled_indices = run_subgraph_sampling(
                  features,
                  self.num_sample,
                  sample_method=self.sample_method,
              )
-         sampled_indices = torch.sort(sampled_indices).values
-         sampled_features = features[sampled_indices]
+         self.anchor_indices = torch.sort(_sampled_indices).values
+         sampled_features = features[self.anchor_indices]
          OnlineNystrom.fit(self, sampled_features)
 
          _n_not_sampled = _n - len(sampled_features)
          if _n_not_sampled > 0:
-             unsampled_indices = torch.full((_n,), True, device=features.device).scatter_(0, sampled_indices, False)
+             unsampled_indices = torch.full((_n,), True, device=features.device).scatter_(0, self.anchor_indices, False)
              unsampled_features = features[unsampled_indices]
              V_unsampled, _ = OnlineNystrom.update(self, unsampled_features)
          else:
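
Note: the sorted sample indices are now kept as `self.anchor_indices`, so the anchor set backing the Nyström approximation can be inspected after fitting. A sketch, assuming `NCUT` is re-exported at the package root as the `__init__` hunk above suggests:

    import torch
    from nystrom_ncut import NCUT

    features = torch.randn(3000, 100)
    nc = NCUT(n_components=20, num_sample=500)
    X, eigs = nc.fit_transform(features)
    anchors = features[nc.anchor_indices]  # the 500 rows used as Nystrom anchors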
@@ -72,7 +72,7 @@ class OnlineNystrom:
          self.anchor_features = features
 
          self.kernel.fit(self.anchor_features)
-         self.inverse_approximation_dim = max(self.n_components, features.shape[-1]) + 1
+         self.inverse_approximation_dim = max(self.n_components, features.shape[-1] + 1)
          U, L = self._update_to_kernel()  # [n x (? + 1)], [? + 1]
 
          self.transform_matrix = (U / L)[:, :self.n_components]  # [n x n_components]
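
Note: the `+ 1` moves inside the `max`. The two expressions differ exactly when `n_components >= features.shape[-1] + 1`, e.g. `max(100, 12) + 1 == 101` but `max(100, 12 + 1) == 100`, so the inverse-approximation dimension is no longer inflated by one when `n_components` dominates the feature dimension.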
@@ -135,7 +135,7 @@
  def solve_eig(
      A: torch.Tensor,
      num_eig: int,
-     eig_solver: Literal["svd_lowrank", "lobpcg", "svd", "eigh"],
+     eig_solver: EigSolverOptions,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
      """PyTorch implementation of Eigensolver cut without Nystrom-like approximation.
 
@@ -3,9 +3,14 @@ from typing import Literal
 
  import numpy as np
  import torch
- import torch.nn.functional as F
+ import torch.nn.functional as Fn
 
- from .common import ceildiv, lazy_normalize
+ from .common import (
+     DistanceOptions,
+     SampleOptions,
+     ceildiv,
+     lazy_normalize,
+ )
 
 
  @torch.no_grad()
@@ -13,7 +18,7 @@ def run_subgraph_sampling(
      features: torch.Tensor,
      num_sample: int,
      max_draw: int = 1000000,
-     sample_method: Literal["farthest", "random"] = "farthest",
+     sample_method: SampleOptions = "farthest",
  ):
      if num_sample >= features.shape[0]:
          # if too many samples, use all samples and bypass Nystrom-like approximation
@@ -74,7 +79,7 @@ def farthest_point_sampling(
  def distance_from_features(
      features: torch.Tensor,
      features_B: torch.Tensor,
-     distance: Literal["cosine", "euclidean", "rbf"],
+     distance: DistanceOptions,
  ):
      """Compute distance matrix from input features.
      Args:
@@ -96,7 +101,6 @@ def distance_from_features(
          D = D / (2 * features.var(dim=0).sum())
      else:
          raise ValueError("distance should be 'cosine' or 'euclidean', 'rbf'")
-
      return D
 
 
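
Note: in the `rbf` branch the squared Euclidean distances are divided by `2 * features.var(dim=0).sum()`, which matches the usual Gaussian-kernel scaling ||x - y||^2 / (2 * sigma^2), with sigma^2 taken as the summed per-dimension variance of the input features.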
@@ -104,7 +108,7 @@ def affinity_from_features(
      features: torch.Tensor,
      features_B: torch.Tensor = None,
      affinity_focal_gamma: float = 1.0,
-     distance: Literal["cosine", "euclidean", "rbf"] = "cosine",
+     distance: DistanceOptions = "cosine",
  ):
      """Compute affinity matrix from input features.
 
@@ -132,23 +136,23 @@ def affinity_from_features(
      return A
 
 
- def propagate_knn(
-     subgraph_output: torch.Tensor,
-     inp_features: torch.Tensor,
-     subgraph_features: torch.Tensor,
-     knn: int = 10,
-     distance: Literal["cosine", "euclidean", "rbf"] = "cosine",
+ def extrapolate_knn(
+     anchor_features: torch.Tensor,          # [n x d]
+     anchor_output: torch.Tensor,            # [n x d']
+     extrapolation_features: torch.Tensor,   # [m x d]
+     knn: int = 10,                          # k
+     distance: DistanceOptions = "cosine",
      affinity_focal_gamma: float = 1.0,
      chunk_size: int = 8192,
      device: str = None,
-     move_output_to_cpu: bool = False,
- ):
+     move_output_to_cpu: bool = False
+ ) -> torch.Tensor:                          # [m x d']
      """A generic function to propagate new nodes using KNN.
 
      Args:
-         subgraph_output (torch.Tensor): output from subgraph, shape (num_sample, D)
-         inp_features (torch.Tensor): features from existing nodes, shape (new_num_samples, n_features)
-         subgraph_features (torch.Tensor): features from subgraph, shape (num_sample, n_features)
+         anchor_features (torch.Tensor): features from subgraph, shape (num_sample, n_features)
+         anchor_output (torch.Tensor): output from subgraph, shape (num_sample, D)
+         extrapolation_features (torch.Tensor): features from existing nodes, shape (new_num_samples, n_features)
          knn (int): number of KNN to propagate eigenvectors
          distance (str): distance metric, 'cosine' (default) or 'euclidean', 'rbf'
          chunk_size (int): chunk size for matrix multiplication
@@ -160,97 +164,77 @@ def propagate_knn(
      >>> old_eigenvectors = torch.randn(3000, 20)
      >>> old_features = torch.randn(3000, 100)
      >>> new_features = torch.randn(200, 100)
-     >>> new_eigenvectors = propagate_knn(old_eigenvectors, new_features, old_features, knn=3)
+     >>> new_eigenvectors = extrapolate_knn(old_features, old_eigenvectors, new_features, knn=3)
      >>> # new_eigenvectors.shape = (200, 20)
 
      """
-     device = subgraph_output.device if device is None else device
-
-     if knn == 1:
-         return propagate_nearest(
-             subgraph_output,
-             inp_features,
-             subgraph_features,
-             chunk_size=chunk_size,
-             device=device,
-             move_output_to_cpu=move_output_to_cpu,
-         )
+     device = anchor_output.device if device is None else device
 
      # used in nystrom_ncut
      # propagate eigen_vector from subgraph to full graph
-     subgraph_output = subgraph_output.to(device)
+     anchor_output = anchor_output.to(device)
 
-     n_chunks = ceildiv(inp_features.shape[0], chunk_size)
+     n_chunks = ceildiv(extrapolation_features.shape[0], chunk_size)
      V_list = []
-     for _v in torch.chunk(inp_features, n_chunks, dim=0):
-         _v = _v.to(device)
-         _A = affinity_from_features(subgraph_features, _v, affinity_focal_gamma, distance).mT
-
+     for _v in torch.chunk(extrapolation_features, n_chunks, dim=0):
+         _v = _v.to(device)                                                                   # [_m x d]
+         _A = affinity_from_features(anchor_features, _v, affinity_focal_gamma, distance).mT  # [_m x n]
          if knn is not None:
-             mask = torch.full_like(_A, True, dtype=torch.bool)
-             mask[torch.arange(len(_v))[:, None], _A.topk(knn, dim=-1, largest=True).indices] = False
-             _A[mask] = 0.0
-             _A = F.normalize(_A, p=1, dim=-1)
-
-         _V = _A @ subgraph_output
-         if move_output_to_cpu:
-             _V = _V.cpu()
-         V_list.append(_V)
-
-     subgraph_output = torch.cat(V_list, dim=0)
-     return subgraph_output
-
-
- def propagate_nearest(
-     subgraph_output: torch.Tensor,
-     inp_features: torch.Tensor,
-     subgraph_features: torch.Tensor,
-     distance: Literal["cosine", "euclidean", "rbf"] = "cosine",
-     chunk_size: int = 8192,
-     device: str = None,
-     move_output_to_cpu: bool = False,
- ):
-     device = subgraph_output.device if device is None else device
-     if distance == 'cosine':
-         inp_features = lazy_normalize(inp_features, dim=-1)
-         subgraph_features = lazy_normalize(subgraph_features, dim=-1)
-
-     # used in nystrom_tsne, equivalent to propagate_by_knn with knn=1
-     # propagate tSNE from subgraph to full graph
-     V_list = []
-     subgraph_features = subgraph_features.to(device)
-     for i in range(0, inp_features.shape[0], chunk_size):
-         end = min(i + chunk_size, inp_features.shape[0])
-         _v = inp_features[i:end].to(device)
-         _A = -distance_from_features(subgraph_features, _v, distance).mT
-
-         # keep top1 for each row
-         top_idx = _A.argmax(dim=-1).cpu()
-         _V = subgraph_output[top_idx]
+             _A, indices = _A.topk(k=knn, dim=-1, largest=True)  # [_m x k], [_m x k]
+             _anchor_output = anchor_output[indices]             # [_m x k x d]
+         else:
+             _anchor_output = anchor_output[None]                # [1 x n x d]
+         _A = Fn.normalize(_A, p=1, dim=-1)
+
+         # if distance == 'cosine':
+         #     _A = _v @ subgraph_features.T
+         # elif distance == 'euclidean':
+         #     _A = - torch.cdist(_v, subgraph_features, p=2)
+         # elif distance == 'rbf':
+         #     _A = - torch.cdist(_v, subgraph_features, p=2) ** 2
+         # else:
+         #     raise ValueError("distance should be 'cosine' or 'euclidean', 'rbf'")
+         #
+         # # keep topk KNN for each row
+         # topk_sim, topk_idx = _A.topk(knn, dim=-1, largest=True)
+         # row_id = torch.arange(topk_idx.shape[0], device=_A.device)[:, None].expand(
+         #     -1, topk_idx.shape[1]
+         # )
+         # _A = torch.sparse_coo_tensor(
+         #     torch.stack([row_id, topk_idx], dim=-1).reshape(-1, 2).T,
+         #     topk_sim.reshape(-1),
+         #     size=(_A.shape[0], _A.shape[1]),
+         #     device=_A.device,
+         # )
+         # _A = _A.to_dense().to(dtype=subgraph_output.dtype)
+         # _D = _A.sum(-1)
+         # _A /= _D[:, None]
+
+         _V = (_A[:, None, :] @ _anchor_output).squeeze(1)
          if move_output_to_cpu:
              _V = _V.cpu()
          V_list.append(_V)
 
-     subgraph_output = torch.cat(V_list, dim=0)
-     return subgraph_output
+     anchor_output = torch.cat(V_list, dim=0)
+     return anchor_output
 
 
  # wrapper functions for adding new nodes to existing graph
- def propagate_eigenvectors(
-     eigenvectors: torch.Tensor,
-     features: torch.Tensor,
-     new_features: torch.Tensor,
+ def extrapolate_knn_with_subsampling(
+     full_features: torch.Tensor,
+     full_output: torch.Tensor,
+     extrapolation_features: torch.Tensor,
      knn: int,
      num_sample: int,
-     sample_method: Literal["farthest", "random"],
+     sample_method: SampleOptions,
      chunk_size: int,
-     device: str,
+     device: str
  ):
      """Propagate eigenvectors to new nodes using KNN. Note: this is equivalent to the class API `NCUT.transform(new_features)`, except that the sampling is re-done in this function.
      Args:
-         eigenvectors (torch.Tensor): eigenvectors from existing nodes, shape (num_sample, num_eig)
-         features (torch.Tensor): features from existing nodes, shape (n_samples, n_features)
-         new_features (torch.Tensor): features from new nodes, shape (n_new_samples, n_features)
+         full_output (torch.Tensor): eigenvectors from existing nodes, shape (num_sample, num_eig)
+         full_features (torch.Tensor): features from existing nodes, shape (n_samples, n_features)
+         extrapolation_features (torch.Tensor): features from new nodes, shape (n_new_samples, n_features)
          knn (int): number of KNN to propagate eigenvectors, default 3
          num_sample (int): number of samples for subgraph sampling, default 50000
         sample_method (str): sample method, 'farthest' (default) or 'random'
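
Note: the rename also reorders the arguments, anchors first: `extrapolate_knn(anchor_features, anchor_output, extrapolation_features, ...)` rather than `propagate_knn(subgraph_output, inp_features, subgraph_features, ...)`. A usage sketch matching the docstring example, assuming the root-level re-export:

    import torch
    from nystrom_ncut import extrapolate_knn

    old_features = torch.randn(3000, 100)     # anchor features   [n x d]
    old_eigenvectors = torch.randn(3000, 20)  # anchor output     [n x d']
    new_features = torch.randn(200, 100)      # points to extend  [m x d]

    new_eigenvectors = extrapolate_knn(
        old_features,
        old_eigenvectors,
        new_features,
        knn=3,
    )
    assert new_eigenvectors.shape == (200, 20)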
@@ -263,31 +247,31 @@ def propagate_eigenvectors(
      >>> old_eigenvectors = torch.randn(3000, 20)
      >>> old_features = torch.randn(3000, 100)
      >>> new_features = torch.randn(200, 100)
-     >>> new_eigenvectors = propagate_eigenvectors(old_eigenvectors, new_features, old_features, knn=3)
+     >>> new_eigenvectors = extrapolate_knn_with_subsampling(old_features, old_eigenvectors, new_features, knn=3, num_sample=50000, sample_method="farthest", chunk_size=8192, device=None)
      >>> # new_eigenvectors.shape = (200, 20)
      """
 
-     device = eigenvectors.device if device is None else device
+     device = full_output.device if device is None else device
 
      # sample subgraph
-     subgraph_indices = run_subgraph_sampling(
-         features,
+     anchor_indices = run_subgraph_sampling(
+         full_features,
          num_sample,
          sample_method=sample_method,
      )
 
-     subgraph_eigenvectors = eigenvectors[subgraph_indices].to(device)
-     subgraph_features = features[subgraph_indices].to(device)
-     new_features = new_features.to(device)
+     anchor_output = full_output[anchor_indices].to(device)
+     anchor_features = full_features[anchor_indices].to(device)
+     extrapolation_features = extrapolation_features.to(device)
 
      # propagate eigenvectors from subgraph to new nodes
-     new_eigenvectors = propagate_knn(
-         subgraph_eigenvectors,
-         new_features,
-         subgraph_features,
+     new_eigenvectors = extrapolate_knn(
+         anchor_features,
+         anchor_output,
+         extrapolation_features,
          knn=knn,
          chunk_size=chunk_size,
-         device=device,
+         device=device
      )
      return new_eigenvectors
 
@@ -6,11 +6,14 @@ import torch
  import torch.nn.functional as F
  from sklearn.base import BaseEstimator
 
- from .common import lazy_normalize
+ from .common import (
+     DistanceOptions,
+     lazy_normalize,
+ )
  from .propagation_utils import (
      run_subgraph_sampling,
-     propagate_knn,
-     propagate_eigenvectors,
+     extrapolate_knn,
+     extrapolate_knn_with_subsampling,
      quantile_min_max,
      quantile_normalize
  )
@@ -31,14 +34,29 @@ def _rgb_with_dimensionality_reduction(
      reduction_dim: int,
      reduction_kwargs: Dict[str, Any],
      transform_func: Callable[[torch.Tensor], torch.Tensor] = _identity,
+     pre_smooth: bool = True,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+     if pre_smooth:
+         _subgraph_indices = run_subgraph_sampling(
+             features,
+             num_sample,
+             sample_method="farthest",
+         )
+         features = extrapolate_knn(
+             features[_subgraph_indices],
+             features[_subgraph_indices],
+             features,
+             distance="cosine",
+         )
+
      subgraph_indices = run_subgraph_sampling(
          features,
          num_sample,
          sample_method="farthest",
      )
 
-     _inp = features[subgraph_indices].cpu().numpy()
+     _inp = features[subgraph_indices].numpy(force=True)
      _subgraph_embed = reduction(
          n_components=reduction_dim,
          metric=metric,
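
Note: the new `pre_smooth` step (default `True`) denoises features before dimensionality reduction. It farthest-point-samples an anchor subset, then replaces every row with a KNN-weighted average of its nearest anchors via `extrapolate_knn(features[idx], features[idx], features, distance="cosine")`; passing the same tensor as both anchor features and anchor output makes the call a pure smoothing pass over the feature matrix.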
@@ -47,14 +65,14 @@ def _rgb_with_dimensionality_reduction(
      ).fit_transform(_inp)
 
      _subgraph_embed = torch.tensor(_subgraph_embed, dtype=torch.float32)
-     X_nd = transform_func(propagate_knn(
+     X_nd = transform_func(extrapolate_knn(
+         features[subgraph_indices],
          _subgraph_embed,
          features,
-         features[subgraph_indices],
-         distance=metric,
          knn=knn,
+         distance=metric,
          device=device,
-         move_output_to_cpu=True,
+         move_output_to_cpu=True
      ))
      rgb = rgb_func(X_nd, q)
      return X_nd, rgb
@@ -413,48 +431,6 @@ def rgb_from_2d_colormap(X_2d, q=0.95):
      return rgb
 
 
- def propagate_rgb_color(
-     rgb: torch.Tensor,
-     eigenvectors: torch.Tensor,
-     new_eigenvectors: torch.Tensor,
-     knn: int = 10,
-     num_sample: int = 1000,
-     sample_method: Literal["farthest", "random"] = "farthest",
-     chunk_size: int = 8192,
-     device: str = None,
- ):
-     """Propagate RGB color to new nodes using KNN.
-     Args:
-         rgb (torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
-         features (torch.Tensor): features from existing nodes, shape (n_samples, n_features)
-         new_features (torch.Tensor): features from new nodes, shape (n_new_samples, n_features)
-         knn (int): number of KNN to propagate RGB color, default 1
-         num_sample (int): number of samples for subgraph sampling, default 50000
-         sample_method (str): sample method, 'farthest' (default) or 'random'
-         chunk_size (int): chunk size for matrix multiplication, default 8192
-         device (str): device to use for computation, if None, will not change device
-     Returns:
-         torch.Tensor: propagated RGB color for each data sample, shape (n_new_samples, 3)
-
-     Examples:
-         >>> old_rgb = torch.randn(3000, 3)
-         >>> old_eigenvectors = torch.randn(3000, 20)
-         >>> new_eigenvectors = torch.randn(200, 20)
-         >>> new_rgb = propagate_rgb_color(old_rgb, new_eigenvectors, old_eigenvectors)
-         >>> # new_rgb.shape = (200, 3)
-     """
-     return propagate_eigenvectors(
-         eigenvectors=rgb,
-         features=eigenvectors,
-         new_features=new_eigenvectors,
-         knn=knn,
-         num_sample=num_sample,
-         sample_method=sample_method,
-         chunk_size=chunk_size,
-         device=device,
-     )
-
-
  # application: get segmentation mask from a reference eigenvector (point prompt)
  def _transform_heatmap(heatmap, gamma=1.0):
      """Transform the heatmap using gamma, normalize and min-max normalization.
@@ -1,6 +1,6 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.2
  Name: nystrom_ncut
- Version: 0.0.5
+ Version: 0.0.7
  Summary: Normalized Cut and Nyström Approximation
  Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
  Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
@@ -0,0 +1,190 @@
+ import numpy as np
+ import torch
+ import torch.nn.functional as Fn
+ from matplotlib import pyplot as plt
+
+ from src.nystrom_ncut.ncut_pytorch import NCUT, axis_align, affinity_from_features
+ from ncut_pytorch import NCUT as OldNCUT
+ # from ncut_pytorch.src import rgb_from_umap_sphere
+ # from ncut_pytorch.src.new_ncut_pytorch import NewNCUT
+
+ # from ncut_pytorch.ncut_pytorch.backbone_text import load_text_model
+
+
+ if __name__ == "__main__":
+     # torch.manual_seed(1212)
+     # M = torch.randn((7, 3))
+     # W = torch.nn.functional.cosine_similarity(M[:, None], M[None, :], dim=-1)
+     # A = torch.exp(W - 1)
+     # D_s2 = torch.sum(A, dim=-1, keepdim=True) ** -0.5
+     # # print(A)
+     # print(A * D_s2 * D_s2.mT)
+     #
+     # ncut = NCUT(num_eig=7, knn=1, eig_solver="svd")
+     # V, L = ncut.fit_transform(M)
+     # print(V @ torch.diag(L) @ V.mT)
+     # raise Exception()
+
+     # print(load_text_model("meta-llama/Meta-Llama-3.1-8B").cuda())
+     # print(AutoModelForCausalLM.from_pretrained(
+     #     "meta-llama/Meta-Llama-3.1-8B",
+     #     token="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",
+     # ))
+     # # print(transformers.pipeline(
+     # #     "text-generation",
+     # #     model="meta-llama/Meta-Llama-3.1-8B",
+     # #     model_kwargs={"torch_dtype": torch.bfloat16},
+     # #     token="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",
+     # #     device="cpu",
+     # # ))
+     # raise Exception(
+
+     torch.set_printoptions(precision=8, sci_mode=False, linewidth=400)
+     torch.set_default_dtype(torch.float64)
+     torch.manual_seed(1212)
+     np.random.seed(1212)
+
+     n = 120
+     num_sample = 100
+
+     M = torch.rand((n, 12))
+     distance = "rbf"
+
+     A = affinity_from_features(M, distance=distance)
+     R = torch.diag(torch.sum(A, dim=-1) ** -0.5)
+     L = R @ A @ R
+
+     # C = L[num_sample:, num_sample:]
+     #
+     # _A = L[:num_sample, :num_sample]
+     # _B = L[:num_sample, num_sample:]
+     # extrapolated_C = _B.mT @ torch.inverse(_A) @ _B
+     #
+     # RE = torch.abs(extrapolated_C / C - 1)
+     # print(torch.max(RE).item(), torch.mean(RE).item(), torch.min(RE).item())
+
+     n_components = 30  # num_sample
+     eig_solver = "svd"
+
+     def rel_error(X, eigs):
+         _L = X @ torch.diag(eigs) @ X.mT
+         return torch.abs(_L / L - 1)
+
+     def print_re(re):
+         print(f"max: {re.max().item()}, mean: {re.mean().item()}, min: {re.min().item()}")
+
+     nc0 = NCUT(n_components=n_components, num_sample=num_sample, distance=distance, eig_solver=eig_solver)
+     X0, eigs0 = nc0.fit_transform(M)
+
+     re0 = rel_error(X0, eigs0)
+     print_re(re0)
+
+     plt.imshow(re0)
+     plt.colorbar()
+     plt.show()
+
+     plt.scatter(torch.arange(n), torch.linalg.norm(X0, dim=-1))
+     plt.show()
+     raise Exception()
+
+
+     #
+     # # plt.scatter(torch.arange(n), torch.linalg.norm(X0, dim=-1))
+     # # plt.show()
+     # # raise Exception()
+     #
+     # def align_to(X, eigs):
+     #     sign = torch.sign(torch.sum(X0 * X, dim=0))
+     #     return X * sign, eigs
+     #
+     # Xs = []
+     # n_trials = 20
+     # sum_X, sum_eigs = 0.0, 0.0
+     # for _ in range(n_trials):
+     #     nc = NCUT(n_components=n_components, num_sample=num_sample, distance=distance, eig_solver=eig_solver)
+     #     X, eigs = align_to(*nc.fit_transform(M))
+     #     Xs.append(X)
+     #
+     #     re = rel_error(X, eigs)
+     #     print(f"max: {re.max().item()}, mean: {re.mean().item()}, min: {re.min().item()}")
+     #
+     #     # print(X[:3, :10])
+     #     # print(eigs[:10])
+     #
+     #     sum_X = sum_X + X
+     #     sum_eigs = sum_eigs + eigs
+     #
+     # # print(torch.diag(Xs[0].mT @ Xs[1]))
+     # # raise Exception()
+     #
+     # print("=" * 120)
+     # mean_X, mean_eigs = sum_X / n_trials, sum_eigs / n_trials
+     # mean_re = rel_error(mean_X, mean_eigs)
+     # print(f"max: {mean_re.max().item()}, mean: {mean_re.mean().item()}, min: {mean_re.min().item()}")
+     #
+     # raise Exception()
+
+
+
+     ncs = [
+         NCUT(n_components=n_components, num_sample=n, distance=distance, eig_solver=eig_solver),
+         NCUT(n_components=n_components, num_sample=num_sample, distance=distance, eig_solver=eig_solver),
+         # OldNCUT(num_eig=n_components, num_sample=num_sample, knn=10, distance=distance, eig_solver=eig_solver, make_orthogonal=True),
+     ]
+
+     for NC in ncs:
+         torch.manual_seed(1212)
+         np.random.seed(1212)
+         X, eigs = NC.fit_transform(M)
+
+         RE = rel_error(X, eigs)
+         print(f"max: {RE.max().item()}, mean: {RE.mean().item()}, min: {RE.min().item()}")
+
+     # torch.manual_seed(1212)
+     # np.random.seed(1212)
+     #
+     # aX, R = axis_align(X)
+     # print(aX[:3])
+     # print(R)
+     # print(R @ R.mT)
+
+
+
+
+     # import time
+     # n_trials = 10
+     #
+     # with torch.no_grad():
+     #     start_t = time.perf_counter()
+     #     for _ in range(n_trials):
+     #         X, eigs = NC.fit_transform(M)
+     #     end_t = time.perf_counter()
+     #     print(X.min().item(), X.max().item(), eigs)
+     #     print(f"{1e3 * (end_t - start_t) / n_trials}ms")
+     #
+     #     start_t = time.perf_counter()
+     #     for _ in range(n_trials):
+     #         nX, neigs = nNC.fit_transform(M)
+     #     end_t = time.perf_counter()
+     #     print(nX.min().item(), nX.max().item(), neigs)
+     #     print(f"{1e3 * (end_t - start_t) / n_trials}ms")
+     # raise Exception()
+
+     # assert torch.all(torch.isclose(X, torch.Tensor([
+     #     [0.320216, 0.144101, -0.110744, -0.560543, -0.007982],
+     #     [0.297634, 0.662867, 0.146107, 0.277893, 0.553959],
+     #     [0.324994, -0.057295, 0.052916, 0.391666, -0.460911],
+     #     [0.301703, -0.460709, 0.528563, 0.222525, 0.325546],
+     #     [0.316614, 0.043475, -0.526899, 0.100665, -0.030259],
+     #     [0.325425, -0.127884, 0.294540, -0.012173, -0.303528],
+     #     [0.318136, -0.288952, -0.065148, -0.470192, 0.244805],
+     #     [0.309522, -0.352693, -0.473237, 0.234057, 0.276185],
+     #     [0.320464, 0.229301, 0.281134, -0.308938, -0.169746],
+     #     [0.326147, 0.213536, -0.112246, 0.155114, -0.341439]
+     # ]), atol=1e-6)), "Failed assertion"
+
+     # torch.manual_seed(1212)
+     # np.random.seed(1212)
+     # X_2d, rgb = rgb_from_umap_sphere(X)
+     # # X_3d, rgb = rgb_from_cosine_tsne_3d(X)
+     # print(rgb)
@@ -1,111 +0,0 @@
- import numpy as np
- import torch
- import torch.nn.functional as Fn
-
- from src.nystrom_ncut.ncut_pytorch import NCUT, axis_align
- # from ncut_pytorch.src import rgb_from_umap_sphere
- # from ncut_pytorch.src.new_ncut_pytorch import NewNCUT
-
- # from ncut_pytorch.ncut_pytorch.backbone_text import load_text_model
-
-
- if __name__ == "__main__":
-     # torch.manual_seed(1212)
-     # M = torch.randn((7, 3))
-     # W = torch.nn.functional.cosine_similarity(M[:, None], M[None, :], dim=-1)
-     # A = torch.exp(W - 1)
-     # D_s2 = torch.sum(A, dim=-1, keepdim=True) ** -0.5
-     # # print(A)
-     # print(A * D_s2 * D_s2.mT)
-     #
-     # ncut = NCUT(num_eig=7, knn=1, eig_solver="svd")
-     # V, L = ncut.fit_transform(M)
-     # print(V @ torch.diag(L) @ V.mT)
-     # raise Exception()
-
-     # print(load_text_model("meta-llama/Meta-Llama-3.1-8B").cuda())
-     # print(AutoModelForCausalLM.from_pretrained(
-     #     "meta-llama/Meta-Llama-3.1-8B",
-     #     token="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",
-     # ))
-     # # print(transformers.pipeline(
-     # #     "text-generation",
-     # #     model="meta-llama/Meta-Llama-3.1-8B",
-     # #     model_kwargs={"torch_dtype": torch.bfloat16},
-     # #     token="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",
-     # #     device="cpu",
-     # # ))
-     # raise Exception(
-
-     torch.set_printoptions(precision=8, sci_mode=False, linewidth=400)
-     torch.set_default_dtype(torch.float64)
-     torch.manual_seed(1212)
-     np.random.seed(1212)
-
-     M = torch.rand((12000, 12))
-     NC = NCUT(n_components=12, num_sample=10000, sample_method="farthest")
-
-     torch.manual_seed(1212)
-     np.random.seed(1212)
-     X, eigs = NC.fit_transform(M)
-     print(eigs)
-
-     normalized_M = Fn.normalize(M, p=2, dim=-1)
-     A = torch.exp(-(1 - normalized_M @ normalized_M.mT))
-     R = torch.diag(torch.sum(A, dim=-1) ** -0.5)
-     L = R @ A @ R
-     # print(L)
-     # print(X @ torch.diag(eigs) @ X.mT)
-     # print(L)
-     RE = torch.abs(X @ torch.diag(eigs) @ X.mT / L - 1)
-     print(RE.max().item(), RE.mean().item())
-
-     # torch.manual_seed(1212)
-     # np.random.seed(1212)
-     #
-     # aX, R = axis_align(X)
-     # print(aX[:3])
-     # print(R)
-     # print(R @ R.mT)
-     raise Exception()
-
-
-
-
-     # import time
-     # n_trials = 10
-     #
-     # with torch.no_grad():
-     #     start_t = time.perf_counter()
-     #     for _ in range(n_trials):
-     #         X, eigs = NC.fit_transform(M)
-     #     end_t = time.perf_counter()
-     #     print(X.min().item(), X.max().item(), eigs)
-     #     print(f"{1e3 * (end_t - start_t) / n_trials}ms")
-     #
-     #     start_t = time.perf_counter()
-     #     for _ in range(n_trials):
-     #         nX, neigs = nNC.fit_transform(M)
-     #     end_t = time.perf_counter()
-     #     print(nX.min().item(), nX.max().item(), neigs)
-     #     print(f"{1e3 * (end_t - start_t) / n_trials}ms")
-     # raise Exception()
-
-     # assert torch.all(torch.isclose(X, torch.Tensor([
-     #     [0.320216, 0.144101, -0.110744, -0.560543, -0.007982],
-     #     [0.297634, 0.662867, 0.146107, 0.277893, 0.553959],
-     #     [0.324994, -0.057295, 0.052916, 0.391666, -0.460911],
-     #     [0.301703, -0.460709, 0.528563, 0.222525, 0.325546],
-     #     [0.316614, 0.043475, -0.526899, 0.100665, -0.030259],
-     #     [0.325425, -0.127884, 0.294540, -0.012173, -0.303528],
-     #     [0.318136, -0.288952, -0.065148, -0.470192, 0.244805],
-     #     [0.309522, -0.352693, -0.473237, 0.234057, 0.276185],
-     #     [0.320464, 0.229301, 0.281134, -0.308938, -0.169746],
-     #     [0.326147, 0.213536, -0.112246, 0.155114, -0.341439]
-     # ]), atol=1e-6)), "Failed assertion"
-
-     torch.manual_seed(1212)
-     np.random.seed(1212)
-     X_2d, rgb = rgb_from_umap_sphere(X)
-     # X_3d, rgb = rgb_from_cosine_tsne_3d(X)
-     print(rgb)