nystrom-ncut 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nystrom_ncut/nystrom/normalized_cut.py +1 -1
- nystrom_ncut/transformer/axis_align.py +8 -5
- {nystrom_ncut-0.1.10.dist-info → nystrom_ncut-0.2.0.dist-info}/METADATA +1 -1
- {nystrom_ncut-0.1.10.dist-info → nystrom_ncut-0.2.0.dist-info}/RECORD +7 -7
- {nystrom_ncut-0.1.10.dist-info → nystrom_ncut-0.2.0.dist-info}/LICENSE +0 -0
- {nystrom_ncut-0.1.10.dist-info → nystrom_ncut-0.2.0.dist-info}/WHEEL +0 -0
- {nystrom_ncut-0.1.10.dist-info → nystrom_ncut-0.2.0.dist-info}/top_level.txt +0 -0
@@ -104,7 +104,7 @@ class NCut(OnlineNystromSubsampleFit):
|
|
104
104
|
n_components: int = 100,
|
105
105
|
affinity_focal_gamma: float = 1.0,
|
106
106
|
distance: DistanceOptions = "cosine",
|
107
|
-
adaptive_scaling: bool =
|
107
|
+
adaptive_scaling: bool = False,
|
108
108
|
sample_config: SampleConfig = SampleConfig(),
|
109
109
|
eig_solver: EigSolverOptions = "svd_lowrank",
|
110
110
|
chunk_size: int = 8192,
|
@@ -46,7 +46,7 @@ class AxisAlign(TorchTransformerMixin):
|
|
46
46
|
for _ in range(self.max_iter):
|
47
47
|
# Discretize the projected eigenvectors
|
48
48
|
idx = torch.argmax(normalized_X @ self.R.mT, dim=-1)
|
49
|
-
M = torch.zeros((d, d)).index_add_(0, idx, normalized_X)
|
49
|
+
M = torch.zeros((d, d), device=X.device).index_add_(0, idx, normalized_X)
|
50
50
|
|
51
51
|
# Check for convergence
|
52
52
|
objective = torch.norm(M)
|
@@ -62,23 +62,26 @@ class AxisAlign(TorchTransformerMixin):
|
|
62
62
|
if self.sort_method == "count":
|
63
63
|
sort_metric = torch.bincount(idx, minlength=d)
|
64
64
|
elif self.sort_method == "norm":
|
65
|
-
sort_metric = torch.linalg.norm(X @ self.R.mT,
|
65
|
+
sort_metric = torch.linalg.norm(X @ self.R.mT, dim=0)
|
66
66
|
else:
|
67
67
|
raise ValueError(f"Invalid sort method {self.sort_method}.")
|
68
68
|
|
69
69
|
self.R = self.R[torch.argsort(sort_metric, dim=0, descending=True)]
|
70
70
|
return self
|
71
71
|
|
72
|
-
def transform(self, X: torch.Tensor, hard: bool = False) -> torch.Tensor:
|
72
|
+
def transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
|
73
73
|
"""
|
74
74
|
Args:
|
75
75
|
X (torch.Tensor): continuous eigenvectors from NCUT, shape (n, k)
|
76
|
+
normalize (bool): whether to normalize input features before rotating
|
76
77
|
hard (bool): whether to return cluster indices of input features or just the rotated features
|
77
78
|
Returns:
|
78
79
|
torch.Tensor: Discretized eigenvectors, shape (n, k), each row is a one-hot vector.
|
79
80
|
"""
|
81
|
+
if normalize:
|
82
|
+
X = Fn.normalize(X, p=2, dim=1)
|
80
83
|
rotated_X = X @ self.R.mT
|
81
84
|
return torch.argmax(rotated_X, dim=1) if hard else rotated_X
|
82
85
|
|
83
|
-
def fit_transform(self, X: torch.Tensor, hard: bool = False) -> torch.Tensor:
|
84
|
-
return self.fit(X).transform(X, hard=hard)
|
86
|
+
def fit_transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
|
87
|
+
return self.fit(X).transform(X, normalize=normalize, hard=hard)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: nystrom_ncut
|
3
|
-
Version: 0.1.10
|
3
|
+
Version: 0.2.0
|
4
4
|
Summary: Normalized Cut and Nyström Approximation
|
5
5
|
Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
|
6
6
|
Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
|
@@ -6,13 +6,13 @@ nystrom_ncut/sampling_utils.py,sha256=oMmhFcd_N_D15Ht7F0rCGPSgLeitJszAKMD3ICKwHN
|
|
6
6
|
nystrom_ncut/visualize_utils.py,sha256=d3VXjzJPZPPyUMg_b8hKLQoBaRWvutu6u7l36S2gmIM,23007
|
7
7
|
nystrom_ncut/nystrom/__init__.py,sha256=lAoO00i4FG5xqGKDO_OYcSvO4qPK64x_X_hDNBvuLUc,105
|
8
8
|
nystrom_ncut/nystrom/distance_realization.py,sha256=InajllGtRVnLVlZoipZNbHFTGHaTs3zxizKe3kI2Los,5815
|
9
|
-
nystrom_ncut/nystrom/normalized_cut.py,sha256=
|
9
|
+
nystrom_ncut/nystrom/normalized_cut.py,sha256=2ocwc4U3A6GGFs0cuL0DO1yNvt59SJ3uDtj00U0foPM,5906
|
10
10
|
nystrom_ncut/nystrom/nystrom_utils.py,sha256=5w-2GAMb7b6ArZdPEnAnKPFFrsbHSfC-S78cvrR6O20,12806
|
11
11
|
nystrom_ncut/transformer/__init__.py,sha256=jjXjcNp3LrxeF6mqG9VY5k3asrqaY6bXzJz6wTpH78Q,105
|
12
|
-
nystrom_ncut/transformer/axis_align.py,sha256=
|
12
|
+
nystrom_ncut/transformer/axis_align.py,sha256=UHPbZVs3XFDvxAQHJC2La8W534k3nJ776DEfQeJVKxg,3297
|
13
13
|
nystrom_ncut/transformer/transformer_mixin.py,sha256=YAjrDWTL5Hjnk9J2OsoxvtwT2N0u8IdgMSx0rRFmZzE,1653
|
14
|
-
nystrom_ncut-0.
|
15
|
-
nystrom_ncut-0.
|
16
|
-
nystrom_ncut-0.
|
17
|
-
nystrom_ncut-0.
|
18
|
-
nystrom_ncut-0.
|
14
|
+
nystrom_ncut-0.2.0.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
|
15
|
+
nystrom_ncut-0.2.0.dist-info/METADATA,sha256=rKIYYECCawfOcjVLx6OZJAo5ngi2uKriNLbgKVEs_IY,6058
|
16
|
+
nystrom_ncut-0.2.0.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
17
|
+
nystrom_ncut-0.2.0.dist-info/top_level.txt,sha256=gM8IWWHYysIRTCvCTcdS4RShOyl9pxpylgSwPUZR2XM,22
|
18
|
+
nystrom_ncut-0.2.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|