nystrom-ncut 0.0.3__tar.gz → 0.0.5__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nystrom_ncut
3
- Version: 0.0.3
3
+ Version: 0.0.5
4
4
  Summary: Normalized Cut and Nyström Approximation
5
5
  Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
6
6
  Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "nystrom_ncut"
7
- version = "0.0.3"
7
+ version = "0.0.5"
8
8
  authors = [
9
9
  { name = "Huzheng Yang", email = "huze.yann@gmail.com" },
10
10
  { name = "Wentinn Liao", email = "wentinn.liao@gmail.com" },
@@ -173,9 +173,10 @@ class NCUT(OnlineNystrom):
173
173
  else:
174
174
  sampled_indices = run_subgraph_sampling(
175
175
  features,
176
- num_sample=self.num_sample,
176
+ self.num_sample,
177
177
  sample_method=self.sample_method,
178
178
  )
179
+ sampled_indices = torch.sort(sampled_indices).values
179
180
  sampled_features = features[sampled_indices]
180
181
  OnlineNystrom.fit(self, sampled_features)
181
182
 
@@ -52,6 +52,18 @@ class OnlineNystrom:
52
52
  self.transform_matrix: torch.Tensor = None # [n x n_components]
53
53
  self.LS: torch.Tensor = None # [n]
54
54
 
55
+ def _update_to_kernel(self) -> Tuple[torch.Tensor, torch.Tensor]:
56
+ self.A = self.S = self.kernel.transform()
57
+ U, L = solve_eig(
58
+ self.A,
59
+ num_eig=self.inverse_approximation_dim,
60
+ eig_solver=self.eig_solver,
61
+ ) # [n x (? + 1)], [? + 1]
62
+ self.Ahinv_UL = U * (L ** -0.5) # [n x (? + 1)]
63
+ self.Ahinv_VT = U.mT # [(? + 1) x n]
64
+ self.Ahinv = self.Ahinv_UL @ self.Ahinv_VT # [n x n]
65
+ return U, L
66
+
55
67
  def fit(self, features: torch.Tensor):
56
68
  OnlineNystrom.fit_transform(self, features)
57
69
  return self
@@ -60,17 +72,8 @@ class OnlineNystrom:
60
72
  self.anchor_features = features
61
73
 
62
74
  self.kernel.fit(self.anchor_features)
63
- self.A = self.S = self.kernel.transform() # [n x n]
64
-
65
75
  self.inverse_approximation_dim = max(self.n_components, features.shape[-1]) + 1
66
- U, L = solve_eig(
67
- self.A,
68
- num_eig=self.inverse_approximation_dim,
69
- eig_solver=self.eig_solver,
70
- ) # [n x (? + 1)], [? + 1]
71
- self.Ahinv_UL = U * (L ** -0.5) # [n x (? + 1)]
72
- self.Ahinv_VT = U.mT # [(? + 1) x n]
73
- self.Ahinv = self.Ahinv_UL @ self.Ahinv_VT # [n x n]
76
+ U, L = self._update_to_kernel() # [n x (? + 1)], [? + 1]
74
77
 
75
78
  self.transform_matrix = (U / L)[:, :self.n_components] # [n x n_components]
76
79
  self.LS = L[:self.n_components] # [n_components]
@@ -83,6 +86,7 @@ class OnlineNystrom:
83
86
  chunks = torch.chunk(features, n_chunks, dim=0)
84
87
  for chunk in chunks:
85
88
  self.kernel.update(chunk)
89
+ self._update_to_kernel()
86
90
 
87
91
  compressed_BBT = torch.zeros((self.inverse_approximation_dim, self.inverse_approximation_dim)) # [(? + 1) x (? + 1))]
88
92
  for i, chunk in enumerate(chunks):
@@ -101,6 +105,7 @@ class OnlineNystrom:
101
105
  else:
102
106
  """ Unchunked version """
103
107
  B = self.kernel.update(features).mT # [n x m]
108
+ self._update_to_kernel()
104
109
  compressed_B = self.Ahinv_VT @ B # [indirect_pca_dim x m]
105
110
 
106
111
  self.S = self.S + self.Ahinv_UL @ (compressed_B @ compressed_B.mT) @ self.Ahinv_UL.mT # [n x n]
@@ -11,7 +11,7 @@ from .common import ceildiv, lazy_normalize
11
11
  @torch.no_grad()
12
12
  def run_subgraph_sampling(
13
13
  features: torch.Tensor,
14
- num_sample: int = 300,
14
+ num_sample: int,
15
15
  max_draw: int = 1000000,
16
16
  sample_method: Literal["farthest", "random"] = "farthest",
17
17
  ):
@@ -272,7 +272,7 @@ def propagate_eigenvectors(
272
272
  # sample subgraph
273
273
  subgraph_indices = run_subgraph_sampling(
274
274
  features,
275
- num_sample=num_sample,
275
+ num_sample,
276
276
  sample_method=sample_method,
277
277
  )
278
278
 
@@ -34,7 +34,7 @@ def _rgb_with_dimensionality_reduction(
34
34
  ) -> Tuple[torch.Tensor, torch.Tensor]:
35
35
  subgraph_indices = run_subgraph_sampling(
36
36
  features,
37
- num_sample=num_sample,
37
+ num_sample,
38
38
  sample_method="farthest",
39
39
  )
40
40
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nystrom_ncut
3
- Version: 0.0.3
3
+ Version: 0.0.5
4
4
  Summary: Normalized Cut and Nyström Approximation
5
5
  Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
6
6
  Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
@@ -1,5 +1,7 @@
1
1
  import numpy as np
2
2
  import torch
3
+ import torch.nn.functional as Fn
4
+
3
5
  from src.nystrom_ncut.ncut_pytorch import NCUT, axis_align
4
6
  # from ncut_pytorch.src import rgb_from_umap_sphere
5
7
  # from ncut_pytorch.src.new_ncut_pytorch import NewNCUT
@@ -35,28 +37,36 @@ if __name__ == "__main__":
35
37
  # # ))
36
38
  # raise Exception(
37
39
 
38
- torch.set_printoptions(precision=12, sci_mode=False, linewidth=400)
40
+ torch.set_printoptions(precision=8, sci_mode=False, linewidth=400)
41
+ torch.set_default_dtype(torch.float64)
39
42
  torch.manual_seed(1212)
40
43
  np.random.seed(1212)
41
44
 
42
- M = torch.rand((10000, 12))
43
- # NC = NCUT(num_eig=5, knn=None, verbose=True)
44
- kwargs = dict(num_eig=7, sample_method="random")
45
- nNC = NCUT(**kwargs)
45
+ M = torch.rand((12000, 12))
46
+ NC = NCUT(n_components=12, num_sample=10000, sample_method="farthest")
46
47
 
47
48
  torch.manual_seed(1212)
48
49
  np.random.seed(1212)
49
- nX, neigs = nNC.fit_transform(M)
50
- # print(neigs)
51
- # print(nX.mT @ nX)
50
+ X, eigs = NC.fit_transform(M)
51
+ print(eigs)
52
52
 
53
- torch.manual_seed(1212)
54
- np.random.seed(1212)
53
+ normalized_M = Fn.normalize(M, p=2, dim=-1)
54
+ A = torch.exp(-(1 - normalized_M @ normalized_M.mT))
55
+ R = torch.diag(torch.sum(A, dim=-1) ** -0.5)
56
+ L = R @ A @ R
57
+ # print(L)
58
+ # print(X @ torch.diag(eigs) @ X.mT)
59
+ # print(L)
60
+ RE = torch.abs(X @ torch.diag(eigs) @ X.mT / L - 1)
61
+ print(RE.max().item(), RE.mean().item())
55
62
 
56
- aX, R = axis_align(nX)
57
- print(aX[:3])
58
- print(R)
59
- print(R @ R.mT)
63
+ # torch.manual_seed(1212)
64
+ # np.random.seed(1212)
65
+ #
66
+ # aX, R = axis_align(X)
67
+ # print(aX[:3])
68
+ # print(R)
69
+ # print(R @ R.mT)
60
70
  raise Exception()
61
71
 
62
72
 
File without changes
File without changes
File without changes
File without changes