nystrom-ncut 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -54,10 +54,10 @@ class LaplacianKernel(OnlineKernel):
54
54
  self.A,
55
55
  num_eig=d + 1, # d * (d + 3) // 2 + 1,
56
56
  eig_solver=self.eig_solver,
57
- ) # [... x n x (d + 1)], [... x (d + 1)]
58
- self.Ainv = U @ torch.diag_embed(1 / L) @ U.mT # [... x n x n]
59
- self.a_r = torch.where(self.anchor_mask, torch.inf, torch.sum(self.A, dim=-1)) # [... x n]
60
- self.b_r = torch.zeros_like(self.a_r) # [... x n]
57
+ ) # [... x n x (d + 1)], [... x (d + 1)]
58
+ self.Ainv = U @ torch.nan_to_num(torch.diag_embed(1 / L), posinf=0.0, neginf=0.0) @ U.mT # [... x n x n]
59
+ self.a_r = torch.where(self.anchor_mask, torch.inf, torch.sum(self.A, dim=-1)) # [... x n]
60
+ self.b_r = torch.zeros_like(self.a_r) # [... x n]
61
61
 
62
62
  def _affinity(self, features: torch.Tensor) -> torch.Tensor:
63
63
  B = torch.where(self.anchor_mask[..., None], 0.0, affinity_from_features(
@@ -3,6 +3,9 @@ from typing import Literal
3
3
  import torch
4
4
  import torch.nn.functional as Fn
5
5
 
6
+ from ..common import (
7
+ default_device,
8
+ )
6
9
  from .transformer_mixin import (
7
10
  TorchTransformerMixin,
8
11
  )
@@ -27,50 +30,59 @@ class AxisAlign(TorchTransformerMixin):
27
30
 
28
31
  def fit(self, X: torch.Tensor) -> "AxisAlign":
29
32
  # Normalize eigenvectors
30
- n, d = X.shape
31
- normalized_X = Fn.normalize(X, p=2, dim=-1)
32
-
33
- # Initialize R matrix with the first column from a random row of EigenVectors
34
- self.R = torch.empty((d, d), device=X.device)
35
- self.R[0] = normalized_X[torch.randint(0, n, (), device=X.device)]
36
-
37
- # Loop to populate R with k orthogonal directions
38
- c = torch.zeros((n,), device=X.device)
39
- for i in range(1, d):
40
- c += torch.abs(normalized_X @ self.R[i - 1])
41
- self.R[i] = normalized_X[torch.argmin(c, dim=0)]
42
-
43
- # Iterative optimization loop
44
- idx, prev_objective = None, torch.inf
45
- for _ in range(self.max_iter):
46
- # Discretize the projected eigenvectors
47
- idx = torch.argmax(normalized_X @ self.R.mT, dim=-1)
48
- M = torch.zeros((d, d), device=X.device).index_add_(0, idx, normalized_X)
49
-
50
- # Check for convergence
51
- objective = torch.norm(M)
52
- if torch.abs(objective - prev_objective) < torch.finfo(torch.float32).eps:
53
- break
54
- prev_objective = objective
55
-
56
- # SVD decomposition to compute the next R
57
- U, S, Vh = torch.linalg.svd(M, full_matrices=False)
58
- self.R = U @ Vh
59
-
60
- # Permute the rotation matrix so the dimensions are sorted in descending cluster significance
61
- if self.sort_method == "count":
62
- sort_metric = torch.bincount(idx, minlength=d)
63
- elif self.sort_method == "norm":
64
- rotated_X = X @ self.R.mT
65
- sort_metric = torch.linalg.norm(rotated_X, dim=0)
66
- elif self.sort_method == "marginal_norm":
67
- rotated_X = X @ self.R.mT
68
- sort_metric = torch.zeros((d,), device=X.device).index_add_(0, idx, rotated_X[range(n), idx] ** 2)
69
- else:
70
- raise ValueError(f"Invalid sort method {self.sort_method}.")
71
-
72
- self.R = self.R[torch.argsort(sort_metric, dim=0, descending=True)]
73
- return self
33
+ with default_device(X.device):
34
+ d = X.shape[-1]
35
+ normalized_X = Fn.normalize(X, p=2, dim=-1) # float: [... x n x d]
36
+
37
+ # Initialize R matrix with the first column from a random row of EigenVectors
38
+ def get_idx(idx: torch.Tensor) -> torch.Tensor:
39
+ return torch.gather(normalized_X, -2, idx[..., None, None].expand([-1] * (X.ndim - 2) + [1, d]))[..., 0, :]
40
+
41
+ self.R = torch.empty((*X.shape[:-2], d, d)) # float: [... x d x d]
42
+ mask = torch.all(torch.isfinite(normalized_X), dim=-1) # bool: [... x n]
43
+ start_idx = torch.argmax(mask.to(torch.float) + torch.rand(mask.shape), dim=-1) # int: [...]
44
+ self.R[..., 0, :] = get_idx(start_idx)
45
+
46
+ # Loop to populate R with k orthogonal directions
47
+ c = torch.zeros(X.shape[:-1]) # float: [... x n]
48
+ for i in range(1, d):
49
+ c += torch.abs(normalized_X @ self.R[..., i - 1, :, None])[..., 0]
50
+ self.R[..., i, :] = get_idx(torch.argmin(c.nan_to_num(nan=torch.inf), dim=-1))
51
+
52
+ # Iterative optimization loop
53
+ normalized_X = torch.nan_to_num(normalized_X, nan=0.0)
54
+ idx, prev_objective = None, torch.inf
55
+ for _ in range(self.max_iter):
56
+ # Discretize the projected eigenvectors
57
+ idx = torch.argmax(normalized_X @ self.R.mT, dim=-1) # int: [... x n]
58
+ M = torch.sum((idx[..., None] == torch.arange(d))[..., None] * normalized_X[..., :, None, :], dim=-3) # float: [... x d x d]
59
+
60
+ # Check for convergence
61
+ objective = torch.norm(M)
62
+ if torch.abs(objective - prev_objective) < torch.finfo(torch.float32).eps:
63
+ break
64
+ prev_objective = objective
65
+
66
+ # SVD decomposition to compute the next R
67
+ U, S, Vh = torch.linalg.svd(M, full_matrices=False)
68
+ self.R = U @ Vh
69
+
70
+ # Permute the rotation matrix so the dimensions are sorted in descending cluster significance
71
+ match self.sort_method:
72
+ case "count":
73
+ sort_metric = torch.sum((idx[..., None] == torch.arange(d)), dim=-2)
74
+ case "norm":
75
+ rotated_X = torch.nan_to_num(X @ self.R.mT, nan=0.0)
76
+ sort_metric = torch.linalg.norm(rotated_X, dim=-2)
77
+ case "marginal_norm":
78
+ rotated_X = torch.nan_to_num(X @ self.R.mT, nan=0.0)
79
+ sort_metric = torch.sum((idx[..., None] == torch.arange(d)) * (torch.gather(rotated_X, -1, idx[..., None]) ** 2), dim=-2)
80
+ case _:
81
+ raise ValueError(f"Invalid sort method {self.sort_method}.")
82
+
83
+ order = torch.argsort(sort_metric, dim=-1, descending=True)
84
+ self.R = torch.gather(self.R, -2, order[..., None].expand([-1] * order.ndim + [d]))
85
+ return self
74
86
 
75
87
  def transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
76
88
  """
@@ -82,9 +94,9 @@ class AxisAlign(TorchTransformerMixin):
82
94
  torch.Tensor: Discretized eigenvectors, shape (n, k), each row is a one-hot vector.
83
95
  """
84
96
  if normalize:
85
- X = Fn.normalize(X, p=2, dim=1)
97
+ X = Fn.normalize(X, p=2, dim=-1)
86
98
  rotated_X = X @ self.R.mT
87
- return torch.argmax(rotated_X, dim=1) if hard else rotated_X
99
+ return torch.argmax(rotated_X, dim=-1) if hard else rotated_X
88
100
 
89
101
  def fit_transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
90
102
  return self.fit(X).transform(X, normalize=normalize, hard=hard)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: nystrom_ncut
3
- Version: 0.3.0
3
+ Version: 0.3.2
4
4
  Summary: Normalized Cut and Nyström Approximation
5
5
  Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
6
6
  Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
@@ -6,13 +6,13 @@ nystrom_ncut/sampling_utils.py,sha256=6lP8F6gftl4mgkavPsD7Vuk4erj4RtgILPhcj3YqLX
6
6
  nystrom_ncut/visualize_utils.py,sha256=Sfi_kKpvFFzBFoJnbo-pQpH2jhs-A6tH64SV_WGoq58,22740
7
7
  nystrom_ncut/nystrom/__init__.py,sha256=1aUXK87g4cXRXqNt6XkZsfyauw1-yv3sv0NmdmkWo-8,42
8
8
  nystrom_ncut/nystrom/distance_realization.py,sha256=RTI1_Q8fCUGAPSbXaVuNA-2B-11CEAfy2CwKWPJj6xQ,5830
9
- nystrom_ncut/nystrom/normalized_cut.py,sha256=jB_QALMY3l5CFfZPsrOFpEaquTrJP17muTrDZXxzUA8,7177
9
+ nystrom_ncut/nystrom/normalized_cut.py,sha256=cjkG8JeDmTPDK8KwfkAIqF9f1dI-D9s1muJ9WWZlUoc,7237
10
10
  nystrom_ncut/nystrom/nystrom_utils.py,sha256=hksDO8uuAb9xKoA1ZafGwXDlQN_gZJn_qHscaSoO8JE,14120
11
11
  nystrom_ncut/transformer/__init__.py,sha256=jjXjcNp3LrxeF6mqG9VY5k3asrqaY6bXzJz6wTpH78Q,105
12
- nystrom_ncut/transformer/axis_align.py,sha256=6LTR-syJ-f4pcbnMexFmFNn1QADDhH5gka6979YBRrI,3549
12
+ nystrom_ncut/transformer/axis_align.py,sha256=j3LlAPrp8O_jQAlwZz-gu3D7n_wICEJranye-YK5wvA,4880
13
13
  nystrom_ncut/transformer/transformer_mixin.py,sha256=YAjrDWTL5Hjnk9J2OsoxvtwT2N0u8IdgMSx0rRFmZzE,1653
14
- nystrom_ncut-0.3.0.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
15
- nystrom_ncut-0.3.0.dist-info/METADATA,sha256=lhxicufu5Eo9HQsUiS_K-CzocemOeNravAaIXeCtriM,6058
16
- nystrom_ncut-0.3.0.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
17
- nystrom_ncut-0.3.0.dist-info/top_level.txt,sha256=gM8IWWHYysIRTCvCTcdS4RShOyl9pxpylgSwPUZR2XM,22
18
- nystrom_ncut-0.3.0.dist-info/RECORD,,
14
+ nystrom_ncut-0.3.2.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
15
+ nystrom_ncut-0.3.2.dist-info/METADATA,sha256=4E42fHnXLNvbErWrwxE5K_oeOdpi5Bfabed9VF-YkV0,6058
16
+ nystrom_ncut-0.3.2.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
17
+ nystrom_ncut-0.3.2.dist-info/top_level.txt,sha256=gM8IWWHYysIRTCvCTcdS4RShOyl9pxpylgSwPUZR2XM,22
18
+ nystrom_ncut-0.3.2.dist-info/RECORD,,