nystrom-ncut 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -104,7 +104,7 @@ class NCut(OnlineNystromSubsampleFit):
104
104
  n_components: int = 100,
105
105
  affinity_focal_gamma: float = 1.0,
106
106
  distance: DistanceOptions = "cosine",
107
- adaptive_scaling: bool = True,
107
+ adaptive_scaling: bool = False,
108
108
  sample_config: SampleConfig = SampleConfig(),
109
109
  eig_solver: EigSolverOptions = "svd_lowrank",
110
110
  chunk_size: int = 8192,
@@ -46,7 +46,7 @@ class AxisAlign(TorchTransformerMixin):
46
46
  for _ in range(self.max_iter):
47
47
  # Discretize the projected eigenvectors
48
48
  idx = torch.argmax(normalized_X @ self.R.mT, dim=-1)
49
- M = torch.zeros((d, d)).index_add_(0, idx, normalized_X)
49
+ M = torch.zeros((d, d), device=X.device).index_add_(0, idx, normalized_X)
50
50
 
51
51
  # Check for convergence
52
52
  objective = torch.norm(M)
@@ -62,23 +62,26 @@ class AxisAlign(TorchTransformerMixin):
62
62
  if self.sort_method == "count":
63
63
  sort_metric = torch.bincount(idx, minlength=d)
64
64
  elif self.sort_method == "norm":
65
- sort_metric = torch.linalg.norm(X @ self.R.mT, p=2, dim=0)
65
+ sort_metric = torch.linalg.norm(X @ self.R.mT, dim=0)
66
66
  else:
67
67
  raise ValueError(f"Invalid sort method {self.sort_method}.")
68
68
 
69
69
  self.R = self.R[torch.argsort(sort_metric, dim=0, descending=True)]
70
70
  return self
71
71
 
72
- def transform(self, X: torch.Tensor, hard: bool = False) -> torch.Tensor:
72
+ def transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
73
73
  """
74
74
  Args:
75
75
  X (torch.Tensor): continuous eigenvectors from NCUT, shape (n, k)
76
+ normalize (bool): whether to normalize input features before rotating
76
77
  hard (bool): whether to return cluster indices of input features or just the rotated features
77
78
  Returns:
78
79
  torch.Tensor: Discretized eigenvectors, shape (n, k), each row is a one-hot vector.
79
80
  """
81
+ if normalize:
82
+ X = Fn.normalize(X, p=2, dim=1)
80
83
  rotated_X = X @ self.R.mT
81
84
  return torch.argmax(rotated_X, dim=1) if hard else rotated_X
82
85
 
83
- def fit_transform(self, X: torch.Tensor, hard: bool = False) -> torch.Tensor:
84
- return self.fit(X).transform(X, hard=hard)
86
+ def fit_transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
87
+ return self.fit(X).transform(X, normalize=normalize, hard=hard)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: nystrom_ncut
3
- Version: 0.1.10
3
+ Version: 0.2.0
4
4
  Summary: Normalized Cut and Nyström Approximation
5
5
  Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
6
6
  Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
@@ -6,13 +6,13 @@ nystrom_ncut/sampling_utils.py,sha256=oMmhFcd_N_D15Ht7F0rCGPSgLeitJszAKMD3ICKwHN
6
6
  nystrom_ncut/visualize_utils.py,sha256=d3VXjzJPZPPyUMg_b8hKLQoBaRWvutu6u7l36S2gmIM,23007
7
7
  nystrom_ncut/nystrom/__init__.py,sha256=lAoO00i4FG5xqGKDO_OYcSvO4qPK64x_X_hDNBvuLUc,105
8
8
  nystrom_ncut/nystrom/distance_realization.py,sha256=InajllGtRVnLVlZoipZNbHFTGHaTs3zxizKe3kI2Los,5815
9
- nystrom_ncut/nystrom/normalized_cut.py,sha256=5aR-CbRAWQVOA1FlQCuxSKEik9tR9sNLsJVBA7_LXyE,5905
9
+ nystrom_ncut/nystrom/normalized_cut.py,sha256=2ocwc4U3A6GGFs0cuL0DO1yNvt59SJ3uDtj00U0foPM,5906
10
10
  nystrom_ncut/nystrom/nystrom_utils.py,sha256=5w-2GAMb7b6ArZdPEnAnKPFFrsbHSfC-S78cvrR6O20,12806
11
11
  nystrom_ncut/transformer/__init__.py,sha256=jjXjcNp3LrxeF6mqG9VY5k3asrqaY6bXzJz6wTpH78Q,105
12
- nystrom_ncut/transformer/axis_align.py,sha256=pX7wk4O6fj-CwRv1TYyPYXsTmmXUtQ5q0c5fDQBVE6Q,3068
12
+ nystrom_ncut/transformer/axis_align.py,sha256=UHPbZVs3XFDvxAQHJC2La8W534k3nJ776DEfQeJVKxg,3297
13
13
  nystrom_ncut/transformer/transformer_mixin.py,sha256=YAjrDWTL5Hjnk9J2OsoxvtwT2N0u8IdgMSx0rRFmZzE,1653
14
- nystrom_ncut-0.1.10.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
15
- nystrom_ncut-0.1.10.dist-info/METADATA,sha256=klm6vyp2lsQA82we85bhvK6Xz2pAPhBeJw57lYN0abA,6059
16
- nystrom_ncut-0.1.10.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
17
- nystrom_ncut-0.1.10.dist-info/top_level.txt,sha256=gM8IWWHYysIRTCvCTcdS4RShOyl9pxpylgSwPUZR2XM,22
18
- nystrom_ncut-0.1.10.dist-info/RECORD,,
14
+ nystrom_ncut-0.2.0.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
15
+ nystrom_ncut-0.2.0.dist-info/METADATA,sha256=rKIYYECCawfOcjVLx6OZJAo5ngi2uKriNLbgKVEs_IY,6058
16
+ nystrom_ncut-0.2.0.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
17
+ nystrom_ncut-0.2.0.dist-info/top_level.txt,sha256=gM8IWWHYysIRTCvCTcdS4RShOyl9pxpylgSwPUZR2XM,22
18
+ nystrom_ncut-0.2.0.dist-info/RECORD,,