nystrom-ncut 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nystrom_ncut/__init__.py CHANGED
@@ -1,6 +1,8 @@
1
1
  from .nystrom import (
2
2
  NCut,
3
- axis_align,
3
+ )
4
+ from .transformer import (
5
+ AxisAlign,
4
6
  )
5
7
  from .distance_utils import (
6
8
  distance_from_features,
@@ -3,5 +3,4 @@ from .distance_realization import (
3
3
  )
4
4
  from .normalized_cut import (
5
5
  NCut,
6
- axis_align,
7
6
  )
@@ -1,6 +1,5 @@
1
1
  import einops
2
2
  import torch
3
- import torch.nn.functional as Fn
4
3
 
5
4
  from .nystrom_utils import (
6
5
  EigSolverOptions,
@@ -131,50 +130,3 @@ class NCut(OnlineNystromSubsampleFit):
131
130
  eig_solver=eig_solver,
132
131
  chunk_size=chunk_size,
133
132
  )
134
-
135
-
136
def axis_align(eigen_vectors: torch.Tensor, max_iter=300):
    """Multiclass Spectral Clustering, SX Yu, J Shi, 2003

    Discretize continuous NCut eigenvectors by finding the rotation ``R`` that best
    aligns the row-normalized eigenvectors with the coordinate axes.

    Args:
        eigen_vectors (torch.Tensor): continuous eigenvectors from NCUT, shape (n, k)
        max_iter (int, optional): Maximum number of iterations.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            - discretized eigenvectors, shape (n, k), float, each row a one-hot vector;
            - the fitted rotation ``R``, shape (k, k).
    """
    n, k = eigen_vectors.shape
    # Allocate every buffer on the input's device and dtype: a CPU/float32 buffer
    # combined with a CUDA or float64 input fails in matmul / index_add_.
    factory_kwargs = {"device": eigen_vectors.device, "dtype": eigen_vectors.dtype}
    eigen_vectors = Fn.normalize(eigen_vectors, p=2, dim=-1)

    # Greedy initialization: start from a random row, then repeatedly add the row
    # least aligned (in accumulated absolute cosine) with the directions chosen so far.
    R = torch.empty((k, k), **factory_kwargs)
    R[0] = eigen_vectors[torch.randint(0, n, (1,)).item()]
    c = torch.zeros(n, **factory_kwargs)
    for i in range(1, k):
        c += torch.abs(eigen_vectors @ R[i - 1])
        R[i] = eigen_vectors[torch.argmin(c, dim=0)]

    # Alternate between discretizing (each row -> its best axis) and solving for the
    # rotation via SVD, until the objective stops changing.
    eps = torch.finfo(torch.float32).eps
    idx, prev_objective = None, torch.inf
    for _ in range(max_iter):
        idx = torch.argmax(eigen_vectors @ R.mT, dim=-1)
        # M[j] = sum of the normalized rows currently assigned to axis j
        M = torch.zeros((k, k), **factory_kwargs).index_add_(0, idx, eigen_vectors)

        objective = torch.norm(M)
        if torch.abs(objective - prev_objective) < eps:
            break
        prev_objective = objective

        U, _, Vh = torch.linalg.svd(M, full_matrices=False)
        R = U @ Vh

    if idx is None:
        # max_iter <= 0: the loop never ran, so discretize with the initial R.
        idx = torch.argmax(eigen_vectors @ R.mT, dim=-1)

    return Fn.one_hot(idx, num_classes=k).to(torch.float), R
@@ -1,5 +1,6 @@
1
1
  import copy
2
2
  import logging
3
+ from abc import abstractmethod
3
4
  from typing import Literal, Tuple
4
5
 
5
6
  import torch
@@ -14,23 +15,29 @@ from ..sampling_utils import (
14
15
  SampleConfig,
15
16
  subsample_features,
16
17
  )
18
+ from ..transformer import (
19
+ TorchTransformerMixin,
20
+ )
17
21
 
18
22
 
19
23
  EigSolverOptions = Literal["svd_lowrank", "lobpcg", "svd", "eigh"]
20
24
 
21
25
 
22
26
  class OnlineKernel:
27
+ @abstractmethod
23
28
  def fit(self, features: torch.Tensor) -> "OnlineKernel": # [n x d]
24
- raise NotImplementedError()
29
+ """"""
25
30
 
31
+ @abstractmethod
26
32
  def update(self, features: torch.Tensor) -> torch.Tensor: # [m x d] -> [m x n]
27
- raise NotImplementedError()
33
+ """"""
28
34
 
35
+ @abstractmethod
29
36
  def transform(self, features: torch.Tensor = None) -> torch.Tensor: # [m x d] -> [m x n]
30
- raise NotImplementedError()
37
+ """"""
31
38
 
32
39
 
33
- class OnlineNystrom:
40
+ class OnlineNystrom(TorchTransformerMixin):
34
41
  def __init__(
35
42
  self,
36
43
  n_components: int,
@@ -0,0 +1,6 @@
1
+ from .transformer_mixin import (
2
+ TorchTransformerMixin,
3
+ )
4
+ from .axis_align import (
5
+ AxisAlign,
6
+ )
@@ -0,0 +1,84 @@
1
+ import random
2
+ from typing import Literal
3
+
4
+ import torch
5
+ import torch.nn.functional as Fn
6
+
7
+ from .transformer_mixin import (
8
+ TorchTransformerMixin,
9
+ )
10
+
11
+
12
class AxisAlign(TorchTransformerMixin):
    """Multiclass Spectral Clustering, SX Yu, J Shi, 2003

    Discretizes continuous NCut eigenvectors by fitting an orthogonal rotation ``R``
    that aligns the row-normalized eigenvectors with the coordinate axes.

    Args:
        sort_method (str, optional): how the rows of the fitted rotation (i.e. the
            output dimensions / clusters) are ordered after fitting, descending by
            - ``"count"``: number of samples assigned to each axis;
            - ``"norm"``: L2 norm of the rotated (un-normalized) input along each axis.
        max_iter (int, optional): Maximum number of iterations.

    Attributes:
        R (torch.Tensor): fitted rotation, shape (d, d); ``None`` until ``fit`` is called.
    """
    SortOptions = Literal["count", "norm"]

    def __init__(
        self,
        sort_method: SortOptions = "norm",
        max_iter: int = 100,
    ):
        self.sort_method: AxisAlign.SortOptions = sort_method
        self.max_iter: int = max_iter

        self.R: torch.Tensor = None

    def fit(self, X: torch.Tensor) -> "AxisAlign":
        """Fit the rotation ``R`` to the rows of ``X``.

        Args:
            X (torch.Tensor): continuous eigenvectors from NCUT, shape (n, d).
        Returns:
            AxisAlign: ``self``.
        Raises:
            ValueError: if ``sort_method`` is not recognized or ``X`` has no rows.
        """
        # Validate up front instead of after all the iterative work.
        if self.sort_method not in ("count", "norm"):
            raise ValueError(f"Invalid sort method {self.sort_method}.")

        n, d = X.shape
        if n == 0:
            raise ValueError("AxisAlign.fit requires at least one sample.")

        # Allocate every buffer on X's device and dtype: a CPU/float32 buffer combined
        # with a CUDA or float64 X fails in matmul / index_add_.
        factory_kwargs = {"device": X.device, "dtype": X.dtype}
        normalized_X = Fn.normalize(X, p=2, dim=-1)

        # Greedy initialization: start from a random sample, then repeatedly add the
        # sample least aligned (in accumulated absolute cosine) with the chosen directions.
        R = torch.empty((d, d), **factory_kwargs)
        R[0] = normalized_X[random.randint(0, n - 1)]
        c = torch.zeros((n,), **factory_kwargs)
        for i in range(1, d):
            c += torch.abs(normalized_X @ R[i - 1])
            R[i] = normalized_X[torch.argmin(c, dim=0)]

        # Alternate between discretizing (each sample -> its best axis) and solving
        # for the rotation via SVD, until the objective stops changing.
        eps = torch.finfo(torch.float32).eps
        idx, prev_objective = None, torch.inf
        for _ in range(self.max_iter):
            idx = torch.argmax(normalized_X @ R.mT, dim=-1)
            # M[j] = sum of the normalized samples currently assigned to axis j
            M = torch.zeros((d, d), **factory_kwargs).index_add_(0, idx, normalized_X)

            objective = torch.norm(M)
            if torch.abs(objective - prev_objective) < eps:
                break
            prev_objective = objective

            U, _, Vh = torch.linalg.svd(M, full_matrices=False)
            R = U @ Vh

        if idx is None:
            # max_iter <= 0: the loop never ran, so discretize with the initial R.
            idx = torch.argmax(normalized_X @ R.mT, dim=-1)

        # Permute the rotation so the dimensions are sorted in descending cluster significance
        if self.sort_method == "count":
            sort_metric = torch.bincount(idx, minlength=d)
        else:  # "norm" -- torch.linalg.norm takes `ord`, not `p`
            sort_metric = torch.linalg.norm(X @ R.mT, ord=2, dim=0)

        self.R = R[torch.argsort(sort_metric, dim=0, descending=True)]
        return self

    def transform(self, X: torch.Tensor, hard: bool = False) -> torch.Tensor:
        """Rotate ``X`` with the fitted rotation.

        Args:
            X (torch.Tensor): continuous eigenvectors from NCUT, shape (n, d)
            hard (bool): if True, return the index of the largest rotated coordinate of
                each sample (its cluster) instead of the rotated features.
        Returns:
            torch.Tensor: rotated features, shape (n, d); or, when ``hard`` is True,
            cluster indices of shape (n,).
        """
        rotated_X = X @ self.R.mT
        return torch.argmax(rotated_X, dim=1) if hard else rotated_X

    def fit_transform(self, X: torch.Tensor, hard: bool = False) -> torch.Tensor:
        """Fit the rotation to ``X``, then return ``transform(X, hard=hard)``."""
        return self.fit(X).transform(X, hard=hard)
@@ -0,0 +1,51 @@
1
+ from abc import abstractmethod
2
+ from typing import Any
3
+
4
+ import torch
5
+ from sklearn.base import TransformerMixin, BaseEstimator
6
+
7
+
8
class TorchTransformerMixin(TransformerMixin, BaseEstimator):
    """Base class for scikit-learn style transformers operating on ``torch.Tensor`` data.

    Subclasses implement ``fit(X, **fit_kwargs)`` (returning ``self``) and
    ``transform(X, **transform_kwargs)``. ``fit_transform`` has a default
    implementation; override it when ``transform`` needs keyword arguments.

    NOTE: this class does not declare ``abc.ABCMeta``, so ``@abstractmethod`` below
    documents the required interface but is not enforced at instantiation time.
    The default bodies raise ``NotImplementedError`` so that a missing override
    fails loudly instead of silently returning ``None``.
    """

    @abstractmethod
    def fit(self, X: torch.Tensor, **fit_kwargs: Any) -> "TorchTransformerMixin":
        """Fit the transformer to ``X`` and return ``self``."""
        raise NotImplementedError()

    @abstractmethod
    def transform(self, X: torch.Tensor, **transform_kwargs: Any) -> torch.Tensor:
        """Transform ``X`` with the fitted transformer."""
        raise NotImplementedError()

    def fit_transform(self, X: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Fit to ``X``, then transform ``X``.

        Default: ``self.fit(X, **kwargs).transform(X)``. Previously this was an empty
        abstract stub that shadowed scikit-learn's concrete ``fit_transform`` and
        returned ``None`` for subclasses that did not override it.
        """
        return self.fit(X, **kwargs).transform(X)
@@ -1,10 +1,10 @@
1
1
  import logging
2
- from typing import Any, Callable, Dict, Literal
2
+ from typing import Any, Callable, Dict, Literal, Union
3
3
 
4
4
  import numpy as np
5
5
  import torch
6
6
  import torch.nn.functional as Fn
7
- from sklearn.base import BaseEstimator
7
+ from sklearn.base import TransformerMixin, BaseEstimator
8
8
 
9
9
  from .common import (
10
10
  ceildiv,
@@ -152,7 +152,7 @@ def _rgb_with_dimensionality_reduction(
152
152
  rgb_func: Callable[[torch.Tensor, float], torch.Tensor],
153
153
  q: float,
154
154
  knn: int,
155
- reduction: Callable[..., BaseEstimator],
155
+ reduction: Callable[..., Union[TransformerMixin, BaseEstimator]],
156
156
  reduction_dim: int,
157
157
  reduction_kwargs: Dict[str, Any],
158
158
  seed: int,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: nystrom_ncut
3
- Version: 0.1.7
3
+ Version: 0.1.9
4
4
  Summary: Normalized Cut and Nyström Approximation
5
5
  Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
6
6
  Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
@@ -0,0 +1,18 @@
1
+ __init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ nystrom_ncut/__init__.py,sha256=tKq9-2QRNFetckHY77qAaKEMjMCYTYcorS2f74aNtvk,540
3
+ nystrom_ncut/common.py,sha256=_PGJoImSk_Fb_5Ri-e_IsFoCcSfbGS8CxYUUHVoNM50,2036
4
+ nystrom_ncut/distance_utils.py,sha256=p-pYdpRrJsIhzxM_IxUqja7N8okngx52WGXD9pu_Aec,3129
5
+ nystrom_ncut/sampling_utils.py,sha256=oMmhFcd_N_D15Ht7F0rCGPSgLeitJszAKMD3ICKwHNU,3105
6
+ nystrom_ncut/visualize_utils.py,sha256=d3VXjzJPZPPyUMg_b8hKLQoBaRWvutu6u7l36S2gmIM,23007
7
+ nystrom_ncut/nystrom/__init__.py,sha256=lAoO00i4FG5xqGKDO_OYcSvO4qPK64x_X_hDNBvuLUc,105
8
+ nystrom_ncut/nystrom/distance_realization.py,sha256=InajllGtRVnLVlZoipZNbHFTGHaTs3zxizKe3kI2Los,5815
9
+ nystrom_ncut/nystrom/normalized_cut.py,sha256=5aR-CbRAWQVOA1FlQCuxSKEik9tR9sNLsJVBA7_LXyE,5905
10
+ nystrom_ncut/nystrom/nystrom_utils.py,sha256=5w-2GAMb7b6ArZdPEnAnKPFFrsbHSfC-S78cvrR6O20,12806
11
+ nystrom_ncut/transformer/__init__.py,sha256=jjXjcNp3LrxeF6mqG9VY5k3asrqaY6bXzJz6wTpH78Q,105
12
+ nystrom_ncut/transformer/axis_align.py,sha256=pX7wk4O6fj-CwRv1TYyPYXsTmmXUtQ5q0c5fDQBVE6Q,3068
13
+ nystrom_ncut/transformer/transformer_mixin.py,sha256=fTNtDFYPw2Fc8mjvK2xNHOw5mCkbO0usUpOnnJdyr5M,1743
14
+ nystrom_ncut-0.1.9.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
15
+ nystrom_ncut-0.1.9.dist-info/METADATA,sha256=8ez3ayc8UcBR8R8Ds7nRAKbrEa3766WNDrQwXToQ9ZM,6058
16
+ nystrom_ncut-0.1.9.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
17
+ nystrom_ncut-0.1.9.dist-info/top_level.txt,sha256=gM8IWWHYysIRTCvCTcdS4RShOyl9pxpylgSwPUZR2XM,22
18
+ nystrom_ncut-0.1.9.dist-info/RECORD,,
@@ -1,15 +0,0 @@
1
- __init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- nystrom_ncut/__init__.py,sha256=HifrTcqX2-hYjBDe6xIThHvuIBYMPBA3EzjR8-qPMUM,512
3
- nystrom_ncut/common.py,sha256=_PGJoImSk_Fb_5Ri-e_IsFoCcSfbGS8CxYUUHVoNM50,2036
4
- nystrom_ncut/distance_utils.py,sha256=p-pYdpRrJsIhzxM_IxUqja7N8okngx52WGXD9pu_Aec,3129
5
- nystrom_ncut/sampling_utils.py,sha256=oMmhFcd_N_D15Ht7F0rCGPSgLeitJszAKMD3ICKwHNU,3105
6
- nystrom_ncut/visualize_utils.py,sha256=_J6YjWUsBe0VqW6KXsQx_iPmRCcO-ie0g6t5mD289UI,22957
7
- nystrom_ncut/nystrom/__init__.py,sha256=4EpxD3Cmc8Fif4vo8DG-6FpTfCnNanD5zCZxK3WrMwQ,121
8
- nystrom_ncut/nystrom/distance_realization.py,sha256=InajllGtRVnLVlZoipZNbHFTGHaTs3zxizKe3kI2Los,5815
9
- nystrom_ncut/nystrom/normalized_cut.py,sha256=N-M5wkTo59vpbBfIx8evkSQBxlo4j80qCtuoifxQa_A,7578
10
- nystrom_ncut/nystrom/nystrom_utils.py,sha256=UVs1tC7vnVq2mWSTpcrP4C19x9wDJ77ACht0EltOO2E,12698
11
- nystrom_ncut-0.1.7.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
12
- nystrom_ncut-0.1.7.dist-info/METADATA,sha256=eb0Q6bwCKC4c5bcuJI_PnIaPW5qFYGSgwbPeuWI7EUk,6058
13
- nystrom_ncut-0.1.7.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
14
- nystrom_ncut-0.1.7.dist-info/top_level.txt,sha256=gM8IWWHYysIRTCvCTcdS4RShOyl9pxpylgSwPUZR2XM,22
15
- nystrom_ncut-0.1.7.dist-info/RECORD,,