nystrom-ncut 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,5 @@
1
1
  import random
2
+ from typing import Literal
2
3
 
3
4
  import torch
4
5
  import torch.nn.functional as Fn
@@ -13,31 +14,39 @@ class AxisAlign(TorchTransformerMixin):
13
14
  Args:
14
15
  max_iter (int, optional): Maximum number of iterations.
15
16
  """
16
- def __init__(self, max_iter: int = 100):
17
- self.max_iter = max_iter
17
+ SortOptions = Literal["count", "norm"]
18
+
19
+ def __init__(
20
+ self,
21
+ sort_method: SortOptions = "norm",
22
+ max_iter: int = 100,
23
+ ):
24
+ self.sort_method: AxisAlign.SortOptions = sort_method
25
+ self.max_iter: int = max_iter
26
+
18
27
  self.R: torch.Tensor = None
19
28
 
20
29
  def fit(self, X: torch.Tensor) -> "AxisAlign":
21
30
  # Normalize eigenvectors
22
31
  n, d = X.shape
23
- X = Fn.normalize(X, p=2, dim=-1)
32
+ normalized_X = Fn.normalize(X, p=2, dim=-1)
24
33
 
25
34
  # Initialize R matrix with the first column from a random row of EigenVectors
26
35
  self.R = torch.empty((d, d), device=X.device)
27
- self.R[0] = X[random.randint(0, n - 1)]
36
+ self.R[0] = normalized_X[random.randint(0, n - 1)]
28
37
 
29
38
  # Loop to populate R with k orthogonal directions
30
39
  c = torch.zeros((n,), device=X.device)
31
40
  for i in range(1, d):
32
- c += torch.abs(X @ self.R[i - 1])
33
- self.R[i] = X[torch.argmin(c, dim=0)]
41
+ c += torch.abs(normalized_X @ self.R[i - 1])
42
+ self.R[i] = normalized_X[torch.argmin(c, dim=0)]
34
43
 
35
44
  # Iterative optimization loop
36
45
  idx, prev_objective = None, torch.inf
37
46
  for _ in range(self.max_iter):
38
47
  # Discretize the projected eigenvectors
39
- idx = torch.argmax(X @ self.R.mT, dim=-1)
40
- M = torch.zeros((d, d)).index_add_(0, idx, X)
48
+ idx = torch.argmax(normalized_X @ self.R.mT, dim=-1)
49
+ M = torch.zeros((d, d)).index_add_(0, idx, normalized_X)
41
50
 
42
51
  # Check for convergence
43
52
  objective = torch.norm(M)
@@ -49,8 +58,15 @@ class AxisAlign(TorchTransformerMixin):
49
58
  U, S, Vh = torch.linalg.svd(M, full_matrices=False)
50
59
  self.R = U @ Vh
51
60
 
52
- # Permute the rotation matrix so the dimensions are sorted in descending cluster counts
53
- self.R = self.R[torch.argsort(torch.bincount(idx, minlength=d), dim=0, descending=True)]
61
+ # Permute the rotation matrix so the dimensions are sorted in descending cluster significance
62
+ if self.sort_method == "count":
63
+ sort_metric = torch.bincount(idx, minlength=d)
64
+ elif self.sort_method == "norm":
65
+ sort_metric = torch.linalg.norm(X @ self.R.mT, p=2, dim=0)
66
+ else:
67
+ raise ValueError(f"Invalid sort method {self.sort_method}.")
68
+
69
+ self.R = self.R[torch.argsort(sort_metric, dim=0, descending=True)]
54
70
  return self
55
71
 
56
72
  def transform(self, X: torch.Tensor, hard: bool = False) -> torch.Tensor:
@@ -2,10 +2,9 @@ from abc import abstractmethod
2
2
  from typing import Any
3
3
 
4
4
  import torch
5
- from sklearn.base import TransformerMixin, BaseEstimator
6
5
 
7
6
 
8
- class TorchTransformerMixin(TransformerMixin, BaseEstimator):
7
+ class TorchTransformerMixin:
9
8
  """Mixin class for all transformers in scikit-learn.
10
9
 
11
10
  This mixin defines the following functionality:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: nystrom_ncut
3
- Version: 0.1.8
3
+ Version: 0.1.10
4
4
  Summary: Normalized Cut and Nyström Approximation
5
5
  Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
6
6
  Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
@@ -9,10 +9,10 @@ nystrom_ncut/nystrom/distance_realization.py,sha256=InajllGtRVnLVlZoipZNbHFTGHaT
9
9
  nystrom_ncut/nystrom/normalized_cut.py,sha256=5aR-CbRAWQVOA1FlQCuxSKEik9tR9sNLsJVBA7_LXyE,5905
10
10
  nystrom_ncut/nystrom/nystrom_utils.py,sha256=5w-2GAMb7b6ArZdPEnAnKPFFrsbHSfC-S78cvrR6O20,12806
11
11
  nystrom_ncut/transformer/__init__.py,sha256=jjXjcNp3LrxeF6mqG9VY5k3asrqaY6bXzJz6wTpH78Q,105
12
- nystrom_ncut/transformer/axis_align.py,sha256=bea_5QVm5gY0o_rZB9jOr5fbXWCL_IWB9FI-wSygx00,2513
13
- nystrom_ncut/transformer/transformer_mixin.py,sha256=fTNtDFYPw2Fc8mjvK2xNHOw5mCkbO0usUpOnnJdyr5M,1743
14
- nystrom_ncut-0.1.8.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
15
- nystrom_ncut-0.1.8.dist-info/METADATA,sha256=9hsyObQahN0v45sjnBh9-tJuwY-62T2iU_7vEl3MqkY,6058
16
- nystrom_ncut-0.1.8.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
17
- nystrom_ncut-0.1.8.dist-info/top_level.txt,sha256=gM8IWWHYysIRTCvCTcdS4RShOyl9pxpylgSwPUZR2XM,22
18
- nystrom_ncut-0.1.8.dist-info/RECORD,,
12
+ nystrom_ncut/transformer/axis_align.py,sha256=pX7wk4O6fj-CwRv1TYyPYXsTmmXUtQ5q0c5fDQBVE6Q,3068
13
+ nystrom_ncut/transformer/transformer_mixin.py,sha256=YAjrDWTL5Hjnk9J2OsoxvtwT2N0u8IdgMSx0rRFmZzE,1653
14
+ nystrom_ncut-0.1.10.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
15
+ nystrom_ncut-0.1.10.dist-info/METADATA,sha256=klm6vyp2lsQA82we85bhvK6Xz2pAPhBeJw57lYN0abA,6059
16
+ nystrom_ncut-0.1.10.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
17
+ nystrom_ncut-0.1.10.dist-info/top_level.txt,sha256=gM8IWWHYysIRTCvCTcdS4RShOyl9pxpylgSwPUZR2XM,22
18
+ nystrom_ncut-0.1.10.dist-info/RECORD,,