nystrom-ncut 0.1.10__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nystrom_ncut/nystrom/normalized_cut.py +1 -1
- nystrom_ncut/transformer/axis_align.py +14 -8
- {nystrom_ncut-0.1.10.dist-info → nystrom_ncut-0.2.1.dist-info}/METADATA +1 -1
- {nystrom_ncut-0.1.10.dist-info → nystrom_ncut-0.2.1.dist-info}/RECORD +7 -7
- {nystrom_ncut-0.1.10.dist-info → nystrom_ncut-0.2.1.dist-info}/LICENSE +0 -0
- {nystrom_ncut-0.1.10.dist-info → nystrom_ncut-0.2.1.dist-info}/WHEEL +0 -0
- {nystrom_ncut-0.1.10.dist-info → nystrom_ncut-0.2.1.dist-info}/top_level.txt +0 -0
@@ -104,7 +104,7 @@ class NCut(OnlineNystromSubsampleFit):
|
|
104
104
|
n_components: int = 100,
|
105
105
|
affinity_focal_gamma: float = 1.0,
|
106
106
|
distance: DistanceOptions = "cosine",
|
107
|
-
adaptive_scaling: bool =
|
107
|
+
adaptive_scaling: bool = False,
|
108
108
|
sample_config: SampleConfig = SampleConfig(),
|
109
109
|
eig_solver: EigSolverOptions = "svd_lowrank",
|
110
110
|
chunk_size: int = 8192,
|
@@ -1,4 +1,3 @@
|
|
1
|
-
import random
|
2
1
|
from typing import Literal
|
3
2
|
|
4
3
|
import torch
|
@@ -14,7 +13,7 @@ class AxisAlign(TorchTransformerMixin):
|
|
14
13
|
Args:
|
15
14
|
max_iter (int, optional): Maximum number of iterations.
|
16
15
|
"""
|
17
|
-
SortOptions = Literal["count", "norm"]
|
16
|
+
SortOptions = Literal["count", "norm", "marginal_norm"]
|
18
17
|
|
19
18
|
def __init__(
|
20
19
|
self,
|
@@ -33,7 +32,7 @@ class AxisAlign(TorchTransformerMixin):
|
|
33
32
|
|
34
33
|
# Initialize R matrix with the first column from a random row of EigenVectors
|
35
34
|
self.R = torch.empty((d, d), device=X.device)
|
36
|
-
self.R[0] = normalized_X[
|
35
|
+
self.R[0] = normalized_X[torch.randint(0, n, (), device=X.device)]
|
37
36
|
|
38
37
|
# Loop to populate R with k orthogonal directions
|
39
38
|
c = torch.zeros((n,), device=X.device)
|
@@ -46,7 +45,7 @@ class AxisAlign(TorchTransformerMixin):
|
|
46
45
|
for _ in range(self.max_iter):
|
47
46
|
# Discretize the projected eigenvectors
|
48
47
|
idx = torch.argmax(normalized_X @ self.R.mT, dim=-1)
|
49
|
-
M = torch.zeros((d, d)).index_add_(0, idx, normalized_X)
|
48
|
+
M = torch.zeros((d, d), device=X.device).index_add_(0, idx, normalized_X)
|
50
49
|
|
51
50
|
# Check for convergence
|
52
51
|
objective = torch.norm(M)
|
@@ -62,23 +61,30 @@ class AxisAlign(TorchTransformerMixin):
|
|
62
61
|
if self.sort_method == "count":
|
63
62
|
sort_metric = torch.bincount(idx, minlength=d)
|
64
63
|
elif self.sort_method == "norm":
|
65
|
-
|
64
|
+
rotated_X = X @ self.R.mT
|
65
|
+
sort_metric = torch.linalg.norm(rotated_X, dim=0)
|
66
|
+
elif self.sort_method == "marginal_norm":
|
67
|
+
rotated_X = X @ self.R.mT
|
68
|
+
sort_metric = torch.zeros((d,), device=X.device).index_add_(0, idx, rotated_X[range(n), idx] ** 2)
|
66
69
|
else:
|
67
70
|
raise ValueError(f"Invalid sort method {self.sort_method}.")
|
68
71
|
|
69
72
|
self.R = self.R[torch.argsort(sort_metric, dim=0, descending=True)]
|
70
73
|
return self
|
71
74
|
|
72
|
-
def transform(self, X: torch.Tensor, hard: bool = False) -> torch.Tensor:
|
75
|
+
def transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
|
73
76
|
"""
|
74
77
|
Args:
|
75
78
|
X (torch.Tensor): continuous eigenvectors from NCUT, shape (n, k)
|
79
|
+
normalize (bool): whether to normalize input features before rotating
|
76
80
|
hard (bool): whether to return cluster indices of input features or just the rotated features
|
77
81
|
Returns:
|
78
82
|
torch.Tensor: Discretized eigenvectors, shape (n, k), each row is a one-hot vector.
|
79
83
|
"""
|
84
|
+
if normalize:
|
85
|
+
X = Fn.normalize(X, p=2, dim=1)
|
80
86
|
rotated_X = X @ self.R.mT
|
81
87
|
return torch.argmax(rotated_X, dim=1) if hard else rotated_X
|
82
88
|
|
83
|
-
def fit_transform(self, X: torch.Tensor, hard: bool = False) -> torch.Tensor:
|
84
|
-
return self.fit(X).transform(X, hard=hard)
|
89
|
+
def fit_transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
|
90
|
+
return self.fit(X).transform(X, normalize=normalize, hard=hard)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: nystrom_ncut
|
3
|
-
Version: 0.1
|
3
|
+
Version: 0.2.1
|
4
4
|
Summary: Normalized Cut and Nyström Approximation
|
5
5
|
Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
|
6
6
|
Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
|
@@ -6,13 +6,13 @@ nystrom_ncut/sampling_utils.py,sha256=oMmhFcd_N_D15Ht7F0rCGPSgLeitJszAKMD3ICKwHN
|
|
6
6
|
nystrom_ncut/visualize_utils.py,sha256=d3VXjzJPZPPyUMg_b8hKLQoBaRWvutu6u7l36S2gmIM,23007
|
7
7
|
nystrom_ncut/nystrom/__init__.py,sha256=lAoO00i4FG5xqGKDO_OYcSvO4qPK64x_X_hDNBvuLUc,105
|
8
8
|
nystrom_ncut/nystrom/distance_realization.py,sha256=InajllGtRVnLVlZoipZNbHFTGHaTs3zxizKe3kI2Los,5815
|
9
|
-
nystrom_ncut/nystrom/normalized_cut.py,sha256=
|
9
|
+
nystrom_ncut/nystrom/normalized_cut.py,sha256=2ocwc4U3A6GGFs0cuL0DO1yNvt59SJ3uDtj00U0foPM,5906
|
10
10
|
nystrom_ncut/nystrom/nystrom_utils.py,sha256=5w-2GAMb7b6ArZdPEnAnKPFFrsbHSfC-S78cvrR6O20,12806
|
11
11
|
nystrom_ncut/transformer/__init__.py,sha256=jjXjcNp3LrxeF6mqG9VY5k3asrqaY6bXzJz6wTpH78Q,105
|
12
|
-
nystrom_ncut/transformer/axis_align.py,sha256=
|
12
|
+
nystrom_ncut/transformer/axis_align.py,sha256=6LTR-syJ-f4pcbnMexFmFNn1QADDhH5gka6979YBRrI,3549
|
13
13
|
nystrom_ncut/transformer/transformer_mixin.py,sha256=YAjrDWTL5Hjnk9J2OsoxvtwT2N0u8IdgMSx0rRFmZzE,1653
|
14
|
-
nystrom_ncut-0.1.
|
15
|
-
nystrom_ncut-0.1.
|
16
|
-
nystrom_ncut-0.1.
|
17
|
-
nystrom_ncut-0.1.
|
18
|
-
nystrom_ncut-0.1.
|
14
|
+
nystrom_ncut-0.2.1.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
|
15
|
+
nystrom_ncut-0.2.1.dist-info/METADATA,sha256=l5t4vEFtPANsQY8PK0YHDJ1tw6dZUulU5daxX9T8QC0,6058
|
16
|
+
nystrom_ncut-0.2.1.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
17
|
+
nystrom_ncut-0.2.1.dist-info/top_level.txt,sha256=gM8IWWHYysIRTCvCTcdS4RShOyl9pxpylgSwPUZR2XM,22
|
18
|
+
nystrom_ncut-0.2.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|