nystrom-ncut 0.3.0__tar.gz → 0.3.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nystrom_ncut-0.3.0/src/nystrom_ncut.egg-info → nystrom_ncut-0.3.1}/PKG-INFO +1 -1
- {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.1}/pyproject.toml +1 -1
- nystrom_ncut-0.3.1/src/nystrom_ncut/transformer/axis_align.py +102 -0
- {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.1/src/nystrom_ncut.egg-info}/PKG-INFO +1 -1
- nystrom_ncut-0.3.0/src/nystrom_ncut/transformer/axis_align.py +0 -90
- {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.1}/LICENSE +0 -0
- {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.1}/MANIFEST.in +0 -0
- {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.1}/README.md +0 -0
- {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.1}/requirements.txt +0 -0
- {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.1}/setup.cfg +0 -0
- {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.1}/src/__init__.py +0 -0
- {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.1}/src/nystrom_ncut/__init__.py +0 -0
- {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.1}/src/nystrom_ncut/common.py +0 -0
- {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.1}/src/nystrom_ncut/distance_utils.py +0 -0
- {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.1}/src/nystrom_ncut/nystrom/__init__.py +0 -0
- {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.1}/src/nystrom_ncut/nystrom/distance_realization.py +0 -0
- {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.1}/src/nystrom_ncut/nystrom/normalized_cut.py +0 -0
- {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.1}/src/nystrom_ncut/nystrom/nystrom_utils.py +0 -0
- {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.1}/src/nystrom_ncut/sampling_utils.py +0 -0
- {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.1}/src/nystrom_ncut/transformer/__init__.py +0 -0
- {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.1}/src/nystrom_ncut/transformer/transformer_mixin.py +0 -0
- {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.1}/src/nystrom_ncut/visualize_utils.py +0 -0
- {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.1}/src/nystrom_ncut.egg-info/SOURCES.txt +0 -0
- {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.1}/src/nystrom_ncut.egg-info/dependency_links.txt +0 -0
- {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.1}/src/nystrom_ncut.egg-info/top_level.txt +0 -0
- {nystrom_ncut-0.3.0 → nystrom_ncut-0.3.1}/tests/test.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: nystrom_ncut
|
3
|
-
Version: 0.3.0
|
3
|
+
Version: 0.3.1
|
4
4
|
Summary: Normalized Cut and Nyström Approximation
|
5
5
|
Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
|
6
6
|
Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
|
@@ -0,0 +1,102 @@
|
|
1
|
+
from typing import Literal
|
2
|
+
|
3
|
+
import torch
|
4
|
+
import torch.nn.functional as Fn
|
5
|
+
|
6
|
+
from ..common import (
|
7
|
+
default_device,
|
8
|
+
)
|
9
|
+
from .transformer_mixin import (
|
10
|
+
TorchTransformerMixin,
|
11
|
+
)
|
12
|
+
|
13
|
+
|
14
|
+
class AxisAlign(TorchTransformerMixin):
|
15
|
+
"""Multiclass Spectral Clustering, SX Yu, J Shi, 2003
|
16
|
+
Args:
|
17
|
+
max_iter (int, optional): Maximum number of iterations.
|
18
|
+
"""
|
19
|
+
SortOptions = Literal["count", "norm", "marginal_norm"]
|
20
|
+
|
21
|
+
def __init__(
|
22
|
+
self,
|
23
|
+
sort_method: SortOptions = "norm",
|
24
|
+
max_iter: int = 100,
|
25
|
+
):
|
26
|
+
self.sort_method: AxisAlign.SortOptions = sort_method
|
27
|
+
self.max_iter: int = max_iter
|
28
|
+
|
29
|
+
self.R: torch.Tensor = None
|
30
|
+
|
31
|
+
def fit(self, X: torch.Tensor) -> "AxisAlign":
|
32
|
+
# Normalize eigenvectors
|
33
|
+
with default_device(X.device):
|
34
|
+
d = X.shape[-1]
|
35
|
+
normalized_X = Fn.normalize(X, p=2, dim=-1) # float: [... x n x d]
|
36
|
+
|
37
|
+
# Initialize R matrix with the first column from a random row of EigenVectors
|
38
|
+
def get_idx(idx: torch.Tensor) -> torch.Tensor:
|
39
|
+
return torch.gather(normalized_X, -2, idx[..., None, None].expand([-1] * (X.ndim - 2) + [1, d]))[..., 0, :]
|
40
|
+
|
41
|
+
self.R = torch.empty((*X.shape[:-2], d, d)) # float: [... x d x d]
|
42
|
+
mask = torch.all(torch.isfinite(normalized_X), dim=-1) # bool: [... x n]
|
43
|
+
start_idx = torch.argmax(mask.to(torch.float) + torch.rand(mask.shape), dim=-1) # int: [...]
|
44
|
+
self.R[..., 0, :] = get_idx(start_idx)
|
45
|
+
|
46
|
+
# Loop to populate R with k orthogonal directions
|
47
|
+
c = torch.zeros(X.shape[:-1]) # float: [... x n]
|
48
|
+
for i in range(1, d):
|
49
|
+
c += torch.abs(normalized_X @ self.R[..., i - 1, :, None])[..., 0]
|
50
|
+
self.R[..., i, :] = get_idx(torch.argmin(c.nan_to_num(nan=torch.inf), dim=-1))
|
51
|
+
|
52
|
+
# Iterative optimization loop
|
53
|
+
normalized_X = torch.nan_to_num(normalized_X, nan=0.0)
|
54
|
+
idx, prev_objective = None, torch.inf
|
55
|
+
for _ in range(self.max_iter):
|
56
|
+
# Discretize the projected eigenvectors
|
57
|
+
idx = torch.argmax(normalized_X @ self.R.mT, dim=-1) # int: [... x n]
|
58
|
+
M = torch.sum((idx[..., None] == torch.arange(d))[..., None] * normalized_X[..., :, None, :], dim=-3) # float: [... x d x d]
|
59
|
+
|
60
|
+
# Check for convergence
|
61
|
+
objective = torch.norm(M)
|
62
|
+
if torch.abs(objective - prev_objective) < torch.finfo(torch.float32).eps:
|
63
|
+
break
|
64
|
+
prev_objective = objective
|
65
|
+
|
66
|
+
# SVD decomposition to compute the next R
|
67
|
+
U, S, Vh = torch.linalg.svd(M, full_matrices=False)
|
68
|
+
self.R = U @ Vh
|
69
|
+
|
70
|
+
# Permute the rotation matrix so the dimensions are sorted in descending cluster significance
|
71
|
+
match self.sort_method:
|
72
|
+
case "count":
|
73
|
+
sort_metric = torch.sum((idx[..., None] == torch.arange(d)), dim=-2)
|
74
|
+
case "norm":
|
75
|
+
rotated_X = torch.nan_to_num(X @ self.R.mT, nan=0.0)
|
76
|
+
sort_metric = torch.linalg.norm(rotated_X, dim=-2)
|
77
|
+
case "marginal_norm":
|
78
|
+
rotated_X = torch.nan_to_num(X @ self.R.mT, nan=0.0)
|
79
|
+
sort_metric = torch.sum((idx[..., None] == torch.arange(d)) * (torch.gather(rotated_X, -1, idx[..., None]) ** 2), dim=-2)
|
80
|
+
case _:
|
81
|
+
raise ValueError(f"Invalid sort method {self.sort_method}.")
|
82
|
+
|
83
|
+
order = torch.argsort(sort_metric, dim=-1, descending=True)
|
84
|
+
self.R = torch.gather(self.R, -2, order[..., None].expand([-1] * order.ndim + [d]))
|
85
|
+
return self
|
86
|
+
|
87
|
+
def transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
|
88
|
+
"""
|
89
|
+
Args:
|
90
|
+
X (torch.Tensor): continuous eigenvectors from NCUT, shape (n, k)
|
91
|
+
normalize (bool): whether to normalize input features before rotating
|
92
|
+
hard (bool): whether to return cluster indices of input features or just the rotated features
|
93
|
+
Returns:
|
94
|
+
torch.Tensor: Discretized eigenvectors, shape (n, k), each row is a one-hot vector.
|
95
|
+
"""
|
96
|
+
if normalize:
|
97
|
+
X = Fn.normalize(X, p=2, dim=-1)
|
98
|
+
rotated_X = X @ self.R.mT
|
99
|
+
return torch.argmax(rotated_X, dim=-1) if hard else rotated_X
|
100
|
+
|
101
|
+
def fit_transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
|
102
|
+
return self.fit(X).transform(X, normalize=normalize, hard=hard)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: nystrom_ncut
|
3
|
-
Version: 0.3.0
|
3
|
+
Version: 0.3.1
|
4
4
|
Summary: Normalized Cut and Nyström Approximation
|
5
5
|
Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
|
6
6
|
Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
|
@@ -1,90 +0,0 @@
|
|
1
|
-
from typing import Literal
|
2
|
-
|
3
|
-
import torch
|
4
|
-
import torch.nn.functional as Fn
|
5
|
-
|
6
|
-
from .transformer_mixin import (
|
7
|
-
TorchTransformerMixin,
|
8
|
-
)
|
9
|
-
|
10
|
-
|
11
|
-
class AxisAlign(TorchTransformerMixin):
|
12
|
-
"""Multiclass Spectral Clustering, SX Yu, J Shi, 2003
|
13
|
-
Args:
|
14
|
-
max_iter (int, optional): Maximum number of iterations.
|
15
|
-
"""
|
16
|
-
SortOptions = Literal["count", "norm", "marginal_norm"]
|
17
|
-
|
18
|
-
def __init__(
|
19
|
-
self,
|
20
|
-
sort_method: SortOptions = "norm",
|
21
|
-
max_iter: int = 100,
|
22
|
-
):
|
23
|
-
self.sort_method: AxisAlign.SortOptions = sort_method
|
24
|
-
self.max_iter: int = max_iter
|
25
|
-
|
26
|
-
self.R: torch.Tensor = None
|
27
|
-
|
28
|
-
def fit(self, X: torch.Tensor) -> "AxisAlign":
|
29
|
-
# Normalize eigenvectors
|
30
|
-
n, d = X.shape
|
31
|
-
normalized_X = Fn.normalize(X, p=2, dim=-1)
|
32
|
-
|
33
|
-
# Initialize R matrix with the first column from a random row of EigenVectors
|
34
|
-
self.R = torch.empty((d, d), device=X.device)
|
35
|
-
self.R[0] = normalized_X[torch.randint(0, n, (), device=X.device)]
|
36
|
-
|
37
|
-
# Loop to populate R with k orthogonal directions
|
38
|
-
c = torch.zeros((n,), device=X.device)
|
39
|
-
for i in range(1, d):
|
40
|
-
c += torch.abs(normalized_X @ self.R[i - 1])
|
41
|
-
self.R[i] = normalized_X[torch.argmin(c, dim=0)]
|
42
|
-
|
43
|
-
# Iterative optimization loop
|
44
|
-
idx, prev_objective = None, torch.inf
|
45
|
-
for _ in range(self.max_iter):
|
46
|
-
# Discretize the projected eigenvectors
|
47
|
-
idx = torch.argmax(normalized_X @ self.R.mT, dim=-1)
|
48
|
-
M = torch.zeros((d, d), device=X.device).index_add_(0, idx, normalized_X)
|
49
|
-
|
50
|
-
# Check for convergence
|
51
|
-
objective = torch.norm(M)
|
52
|
-
if torch.abs(objective - prev_objective) < torch.finfo(torch.float32).eps:
|
53
|
-
break
|
54
|
-
prev_objective = objective
|
55
|
-
|
56
|
-
# SVD decomposition to compute the next R
|
57
|
-
U, S, Vh = torch.linalg.svd(M, full_matrices=False)
|
58
|
-
self.R = U @ Vh
|
59
|
-
|
60
|
-
# Permute the rotation matrix so the dimensions are sorted in descending cluster significance
|
61
|
-
if self.sort_method == "count":
|
62
|
-
sort_metric = torch.bincount(idx, minlength=d)
|
63
|
-
elif self.sort_method == "norm":
|
64
|
-
rotated_X = X @ self.R.mT
|
65
|
-
sort_metric = torch.linalg.norm(rotated_X, dim=0)
|
66
|
-
elif self.sort_method == "marginal_norm":
|
67
|
-
rotated_X = X @ self.R.mT
|
68
|
-
sort_metric = torch.zeros((d,), device=X.device).index_add_(0, idx, rotated_X[range(n), idx] ** 2)
|
69
|
-
else:
|
70
|
-
raise ValueError(f"Invalid sort method {self.sort_method}.")
|
71
|
-
|
72
|
-
self.R = self.R[torch.argsort(sort_metric, dim=0, descending=True)]
|
73
|
-
return self
|
74
|
-
|
75
|
-
def transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
|
76
|
-
"""
|
77
|
-
Args:
|
78
|
-
X (torch.Tensor): continuous eigenvectors from NCUT, shape (n, k)
|
79
|
-
normalize (bool): whether to normalize input features before rotating
|
80
|
-
hard (bool): whether to return cluster indices of input features or just the rotated features
|
81
|
-
Returns:
|
82
|
-
torch.Tensor: Discretized eigenvectors, shape (n, k), each row is a one-hot vector.
|
83
|
-
"""
|
84
|
-
if normalize:
|
85
|
-
X = Fn.normalize(X, p=2, dim=1)
|
86
|
-
rotated_X = X @ self.R.mT
|
87
|
-
return torch.argmax(rotated_X, dim=1) if hard else rotated_X
|
88
|
-
|
89
|
-
def fit_transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
|
90
|
-
return self.fit(X).transform(X, normalize=normalize, hard=hard)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|