nystrom-ncut 0.3.0__tar.gz → 0.3.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. {nystrom_ncut-0.3.0/src/nystrom_ncut.egg-info → nystrom_ncut-0.3.2}/PKG-INFO +1 -1
  2. {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.2}/pyproject.toml +1 -1
  3. {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.2}/src/nystrom_ncut/nystrom/normalized_cut.py +4 -4
  4. nystrom_ncut-0.3.2/src/nystrom_ncut/transformer/axis_align.py +102 -0
  5. {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.2/src/nystrom_ncut.egg-info}/PKG-INFO +1 -1
  6. nystrom_ncut-0.3.0/src/nystrom_ncut/transformer/axis_align.py +0 -90
  7. {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.2}/LICENSE +0 -0
  8. {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.2}/MANIFEST.in +0 -0
  9. {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.2}/README.md +0 -0
  10. {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.2}/requirements.txt +0 -0
  11. {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.2}/setup.cfg +0 -0
  12. {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.2}/src/__init__.py +0 -0
  13. {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.2}/src/nystrom_ncut/__init__.py +0 -0
  14. {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.2}/src/nystrom_ncut/common.py +0 -0
  15. {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.2}/src/nystrom_ncut/distance_utils.py +0 -0
  16. {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.2}/src/nystrom_ncut/nystrom/__init__.py +0 -0
  17. {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.2}/src/nystrom_ncut/nystrom/distance_realization.py +0 -0
  18. {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.2}/src/nystrom_ncut/nystrom/nystrom_utils.py +0 -0
  19. {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.2}/src/nystrom_ncut/sampling_utils.py +0 -0
  20. {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.2}/src/nystrom_ncut/transformer/__init__.py +0 -0
  21. {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.2}/src/nystrom_ncut/transformer/transformer_mixin.py +0 -0
  22. {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.2}/src/nystrom_ncut/visualize_utils.py +0 -0
  23. {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.2}/src/nystrom_ncut.egg-info/SOURCES.txt +0 -0
  24. {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.2}/src/nystrom_ncut.egg-info/dependency_links.txt +0 -0
  25. {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.2}/src/nystrom_ncut.egg-info/top_level.txt +0 -0
  26. {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.2}/tests/test.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: nystrom_ncut
-Version: 0.3.0
+Version: 0.3.2
 Summary: Normalized Cut and Nyström Approximation
 Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
 Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "nystrom_ncut"
-version = "0.3.0"
+version = "0.3.2"
 authors = [
     { name = "Huzheng Yang", email = "huze.yann@gmail.com" },
     { name = "Wentinn Liao", email = "wentinn.liao@gmail.com" },
@@ -54,10 +54,10 @@ class LaplacianKernel(OnlineKernel):
             self.A,
             num_eig=d + 1,  # d * (d + 3) // 2 + 1,
             eig_solver=self.eig_solver,
-        )                                               # [... x n x (d + 1)], [... x (d + 1)]
-        self.Ainv = U @ torch.diag_embed(1 / L) @ U.mT  # [... x n x n]
-        self.a_r = torch.where(self.anchor_mask, torch.inf, torch.sum(self.A, dim=-1))  # [... x n]
-        self.b_r = torch.zeros_like(self.a_r)           # [... x n]
+        )                                               # [... x n x (d + 1)], [... x (d + 1)]
+        self.Ainv = U @ torch.nan_to_num(torch.diag_embed(1 / L), posinf=0.0, neginf=0.0) @ U.mT  # [... x n x n]
+        self.a_r = torch.where(self.anchor_mask, torch.inf, torch.sum(self.A, dim=-1))  # [... x n]
+        self.b_r = torch.zeros_like(self.a_r)           # [... x n]
 
     def _affinity(self, features: torch.Tensor) -> torch.Tensor:
         B = torch.where(self.anchor_mask[..., None], 0.0, affinity_from_features(
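The single functional change in this hunk is the torch.nan_to_num(...) wrapper around torch.diag_embed(1 / L). When the affinity matrix is rank-deficient, the eigendecomposition yields (numerically) zero eigenvalues, 1 / L turns them into inf, and self.Ainv fills with inf/nan; zeroing the infinite reciprocals computes a pseudo-inverse instead. A standalone sketch, independent of the package code, illustrating the failure mode and the fix:

import torch

# Sketch only: why guarding 1 / L matters for an eigendecomposition-based inverse.
U = torch.linalg.qr(torch.randn(5, 5)).Q        # orthogonal basis
L = torch.tensor([2.0, 1.0, 0.5, 0.0, 0.0])     # two zero eigenvalues
A = U @ torch.diag_embed(L) @ U.mT              # rank-deficient symmetric matrix

naive = U @ torch.diag_embed(1 / L) @ U.mT      # 1 / 0 = inf poisons the result
guarded = U @ torch.nan_to_num(torch.diag_embed(1 / L), posinf=0.0, neginf=0.0) @ U.mT

assert not torch.isfinite(naive).all()          # naive reconstruction is unusable
# Zeroing the infinite reciprocals reproduces the Moore-Penrose pseudo-inverse.
torch.testing.assert_close(guarded, torch.linalg.pinv(A), atol=1e-5, rtol=1e-5)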
@@ -0,0 +1,102 @@
+from typing import Literal
+
+import torch
+import torch.nn.functional as Fn
+
+from ..common import (
+    default_device,
+)
+from .transformer_mixin import (
+    TorchTransformerMixin,
+)
+
+
+class AxisAlign(TorchTransformerMixin):
+    """Multiclass Spectral Clustering, SX Yu, J Shi, 2003
+    Args:
+        max_iter (int, optional): Maximum number of iterations.
+    """
+    SortOptions = Literal["count", "norm", "marginal_norm"]
+
+    def __init__(
+        self,
+        sort_method: SortOptions = "norm",
+        max_iter: int = 100,
+    ):
+        self.sort_method: AxisAlign.SortOptions = sort_method
+        self.max_iter: int = max_iter
+
+        self.R: torch.Tensor = None
+
+    def fit(self, X: torch.Tensor) -> "AxisAlign":
+        # Normalize eigenvectors
+        with default_device(X.device):
+            d = X.shape[-1]
+            normalized_X = Fn.normalize(X, p=2, dim=-1)  # float: [... x n x d]
+
+            # Initialize R matrix with the first column from a random row of EigenVectors
+            def get_idx(idx: torch.Tensor) -> torch.Tensor:
+                return torch.gather(normalized_X, -2, idx[..., None, None].expand([-1] * (X.ndim - 2) + [1, d]))[..., 0, :]
+
+            self.R = torch.empty((*X.shape[:-2], d, d))  # float: [... x d x d]
+            mask = torch.all(torch.isfinite(normalized_X), dim=-1)  # bool: [... x n]
+            start_idx = torch.argmax(mask.to(torch.float) + torch.rand(mask.shape), dim=-1)  # int: [...]
+            self.R[..., 0, :] = get_idx(start_idx)
+
+            # Loop to populate R with k orthogonal directions
+            c = torch.zeros(X.shape[:-1])  # float: [... x n]
+            for i in range(1, d):
+                c += torch.abs(normalized_X @ self.R[..., i - 1, :, None])[..., 0]
+                self.R[..., i, :] = get_idx(torch.argmin(c.nan_to_num(nan=torch.inf), dim=-1))
+
+            # Iterative optimization loop
+            normalized_X = torch.nan_to_num(normalized_X, nan=0.0)
+            idx, prev_objective = None, torch.inf
+            for _ in range(self.max_iter):
+                # Discretize the projected eigenvectors
+                idx = torch.argmax(normalized_X @ self.R.mT, dim=-1)  # int: [... x n]
+                M = torch.sum((idx[..., None] == torch.arange(d))[..., None] * normalized_X[..., :, None, :], dim=-3)  # float: [... x d x d]
+
+                # Check for convergence
+                objective = torch.norm(M)
+                if torch.abs(objective - prev_objective) < torch.finfo(torch.float32).eps:
+                    break
+                prev_objective = objective
+
+                # SVD decomposition to compute the next R
+                U, S, Vh = torch.linalg.svd(M, full_matrices=False)
+                self.R = U @ Vh
+
+            # Permute the rotation matrix so the dimensions are sorted in descending cluster significance
+            match self.sort_method:
+                case "count":
+                    sort_metric = torch.sum((idx[..., None] == torch.arange(d)), dim=-2)
+                case "norm":
+                    rotated_X = torch.nan_to_num(X @ self.R.mT, nan=0.0)
+                    sort_metric = torch.linalg.norm(rotated_X, dim=-2)
+                case "marginal_norm":
+                    rotated_X = torch.nan_to_num(X @ self.R.mT, nan=0.0)
+                    sort_metric = torch.sum((idx[..., None] == torch.arange(d)) * (torch.gather(rotated_X, -1, idx[..., None]) ** 2), dim=-2)
+                case _:
+                    raise ValueError(f"Invalid sort method {self.sort_method}.")
+
+            order = torch.argsort(sort_metric, dim=-1, descending=True)
+            self.R = torch.gather(self.R, -2, order[..., None].expand([-1] * order.ndim + [d]))
+        return self
+
+    def transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
+        """
+        Args:
+            X (torch.Tensor): continuous eigenvectors from NCUT, shape (n, k)
+            normalize (bool): whether to normalize input features before rotating
+            hard (bool): whether to return cluster indices of input features or just the rotated features
+        Returns:
+            torch.Tensor: Discretized eigenvectors, shape (n, k), each row is a one-hot vector.
+        """
+        if normalize:
+            X = Fn.normalize(X, p=2, dim=-1)
+        rotated_X = X @ self.R.mT
+        return torch.argmax(rotated_X, dim=-1) if hard else rotated_X
+
+    def fit_transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
+        return self.fit(X).transform(X, normalize=normalize, hard=hard)
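Since this file is new in 0.3.2, a minimal usage sketch may help orient readers. The import path mirrors the file location in this diff; whether AxisAlign is also re-exported at the package top level is not shown here, and the input tensors are hypothetical. Note that despite the Returns line in the transform docstring, transform yields rotated continuous features by default and integer cluster indices only with hard=True, and the match statement requires Python 3.10+.

import torch
from nystrom_ncut.transformer.axis_align import AxisAlign  # path as in this diff

eigvecs = torch.randn(1000, 8)                  # hypothetical NCUT eigenvectors, [n x d]

aligner = AxisAlign(sort_method="norm", max_iter=100)
rotated = aligner.fit_transform(eigvecs)        # [1000 x 8], rotated continuous features
labels = aligner.transform(eigvecs, hard=True)  # [1000], argmax cluster index per row

# The rewrite is batch-aware: leading dimensions are carried through,
# producing one rotation matrix per batch element.
R_batched = AxisAlign().fit(torch.randn(4, 1000, 8)).R  # [4 x 8 x 8]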
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: nystrom_ncut
-Version: 0.3.0
+Version: 0.3.2
 Summary: Normalized Cut and Nyström Approximation
 Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
 Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
@@ -1,90 +0,0 @@
-from typing import Literal
-
-import torch
-import torch.nn.functional as Fn
-
-from .transformer_mixin import (
-    TorchTransformerMixin,
-)
-
-
-class AxisAlign(TorchTransformerMixin):
-    """Multiclass Spectral Clustering, SX Yu, J Shi, 2003
-    Args:
-        max_iter (int, optional): Maximum number of iterations.
-    """
-    SortOptions = Literal["count", "norm", "marginal_norm"]
-
-    def __init__(
-        self,
-        sort_method: SortOptions = "norm",
-        max_iter: int = 100,
-    ):
-        self.sort_method: AxisAlign.SortOptions = sort_method
-        self.max_iter: int = max_iter
-
-        self.R: torch.Tensor = None
-
-    def fit(self, X: torch.Tensor) -> "AxisAlign":
-        # Normalize eigenvectors
-        n, d = X.shape
-        normalized_X = Fn.normalize(X, p=2, dim=-1)
-
-        # Initialize R matrix with the first column from a random row of EigenVectors
-        self.R = torch.empty((d, d), device=X.device)
-        self.R[0] = normalized_X[torch.randint(0, n, (), device=X.device)]
-
-        # Loop to populate R with k orthogonal directions
-        c = torch.zeros((n,), device=X.device)
-        for i in range(1, d):
-            c += torch.abs(normalized_X @ self.R[i - 1])
-            self.R[i] = normalized_X[torch.argmin(c, dim=0)]
-
-        # Iterative optimization loop
-        idx, prev_objective = None, torch.inf
-        for _ in range(self.max_iter):
-            # Discretize the projected eigenvectors
-            idx = torch.argmax(normalized_X @ self.R.mT, dim=-1)
-            M = torch.zeros((d, d), device=X.device).index_add_(0, idx, normalized_X)
-
-            # Check for convergence
-            objective = torch.norm(M)
-            if torch.abs(objective - prev_objective) < torch.finfo(torch.float32).eps:
-                break
-            prev_objective = objective
-
-            # SVD decomposition to compute the next R
-            U, S, Vh = torch.linalg.svd(M, full_matrices=False)
-            self.R = U @ Vh
-
-        # Permute the rotation matrix so the dimensions are sorted in descending cluster significance
-        if self.sort_method == "count":
-            sort_metric = torch.bincount(idx, minlength=d)
-        elif self.sort_method == "norm":
-            rotated_X = X @ self.R.mT
-            sort_metric = torch.linalg.norm(rotated_X, dim=0)
-        elif self.sort_method == "marginal_norm":
-            rotated_X = X @ self.R.mT
-            sort_metric = torch.zeros((d,), device=X.device).index_add_(0, idx, rotated_X[range(n), idx] ** 2)
-        else:
-            raise ValueError(f"Invalid sort method {self.sort_method}.")
-
-        self.R = self.R[torch.argsort(sort_metric, dim=0, descending=True)]
-        return self
-
-    def transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
-        """
-        Args:
-            X (torch.Tensor): continuous eigenvectors from NCUT, shape (n, k)
-            normalize (bool): whether to normalize input features before rotating
-            hard (bool): whether to return cluster indices of input features or just the rotated features
-        Returns:
-            torch.Tensor: Discretized eigenvectors, shape (n, k), each row is a one-hot vector.
-        """
-        if normalize:
-            X = Fn.normalize(X, p=2, dim=1)
-        rotated_X = X @ self.R.mT
-        return torch.argmax(rotated_X, dim=1) if hard else rotated_X
-
-    def fit_transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
-        return self.fit(X).transform(X, normalize=normalize, hard=hard)
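Comparing this deleted file with its 0.3.2 replacement above, the other visible behavioral change is NaN tolerance: the new fit masks non-finite rows during initialization (the mask tensor and c.nan_to_num(nan=torch.inf)) and zeroes them before the SVD loop, whereas in this version a single NaN row would reach M and break the SVD. A small sketch of that difference, with hypothetical input data:

import torch
from nystrom_ncut.transformer.axis_align import AxisAlign  # path as in this diff

X = torch.randn(500, 6)
X[::50] = torch.nan            # corrupt every 50th eigenvector row

R = AxisAlign().fit(X).R       # 0.3.2: the fitted rotation stays finite
assert torch.isfinite(R).all()
# transform() still returns NaN for the corrupted rows; only fit() is guarded.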