nystrom-ncut 0.1.9__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -104,7 +104,7 @@ class NCut(OnlineNystromSubsampleFit):
104
104
  n_components: int = 100,
105
105
  affinity_focal_gamma: float = 1.0,
106
106
  distance: DistanceOptions = "cosine",
107
- adaptive_scaling: bool = True,
107
+ adaptive_scaling: bool = False,
108
108
  sample_config: SampleConfig = SampleConfig(),
109
109
  eig_solver: EigSolverOptions = "svd_lowrank",
110
110
  chunk_size: int = 8192,
@@ -46,7 +46,7 @@ class AxisAlign(TorchTransformerMixin):
46
46
  for _ in range(self.max_iter):
47
47
  # Discretize the projected eigenvectors
48
48
  idx = torch.argmax(normalized_X @ self.R.mT, dim=-1)
49
- M = torch.zeros((d, d)).index_add_(0, idx, normalized_X)
49
+ M = torch.zeros((d, d), device=X.device).index_add_(0, idx, normalized_X)
50
50
 
51
51
  # Check for convergence
52
52
  objective = torch.norm(M)
@@ -62,23 +62,26 @@ class AxisAlign(TorchTransformerMixin):
62
62
  if self.sort_method == "count":
63
63
  sort_metric = torch.bincount(idx, minlength=d)
64
64
  elif self.sort_method == "norm":
65
- sort_metric = torch.linalg.norm(X @ self.R.mT, p=2, dim=0)
65
+ sort_metric = torch.linalg.norm(X @ self.R.mT, dim=0)
66
66
  else:
67
67
  raise ValueError(f"Invalid sort method {self.sort_method}.")
68
68
 
69
69
  self.R = self.R[torch.argsort(sort_metric, dim=0, descending=True)]
70
70
  return self
71
71
 
72
- def transform(self, X: torch.Tensor, hard: bool = False) -> torch.Tensor:
72
+ def transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
73
73
  """
74
74
  Args:
75
75
  X (torch.Tensor): continuous eigenvectors from NCUT, shape (n, k)
76
+ normalize (bool): whether to normalize input features before rotating
76
77
  hard (bool): whether to return cluster indices of input features or just the rotated features
77
78
  Returns:
78
79
  torch.Tensor: Discretized eigenvectors, shape (n, k), each row is a one-hot vector.
79
80
  """
81
+ if normalize:
82
+ X = Fn.normalize(X, p=2, dim=1)
80
83
  rotated_X = X @ self.R.mT
81
84
  return torch.argmax(rotated_X, dim=1) if hard else rotated_X
82
85
 
83
- def fit_transform(self, X: torch.Tensor, hard: bool = False) -> torch.Tensor:
84
- return self.fit(X).transform(X, hard=hard)
86
+ def fit_transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
87
+ return self.fit(X).transform(X, normalize=normalize, hard=hard)
@@ -2,10 +2,9 @@ from abc import abstractmethod
2
2
  from typing import Any
3
3
 
4
4
  import torch
5
- from sklearn.base import TransformerMixin, BaseEstimator
6
5
 
7
6
 
8
- class TorchTransformerMixin(TransformerMixin, BaseEstimator):
7
+ class TorchTransformerMixin:
9
8
  """Mixin class for all transformers in scikit-learn.
10
9
 
11
10
  This mixin defines the following functionality:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: nystrom_ncut
3
- Version: 0.1.9
3
+ Version: 0.2.0
4
4
  Summary: Normalized Cut and Nyström Approximation
5
5
  Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
6
6
  Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
@@ -6,13 +6,13 @@ nystrom_ncut/sampling_utils.py,sha256=oMmhFcd_N_D15Ht7F0rCGPSgLeitJszAKMD3ICKwHN
6
6
  nystrom_ncut/visualize_utils.py,sha256=d3VXjzJPZPPyUMg_b8hKLQoBaRWvutu6u7l36S2gmIM,23007
7
7
  nystrom_ncut/nystrom/__init__.py,sha256=lAoO00i4FG5xqGKDO_OYcSvO4qPK64x_X_hDNBvuLUc,105
8
8
  nystrom_ncut/nystrom/distance_realization.py,sha256=InajllGtRVnLVlZoipZNbHFTGHaTs3zxizKe3kI2Los,5815
9
- nystrom_ncut/nystrom/normalized_cut.py,sha256=5aR-CbRAWQVOA1FlQCuxSKEik9tR9sNLsJVBA7_LXyE,5905
9
+ nystrom_ncut/nystrom/normalized_cut.py,sha256=2ocwc4U3A6GGFs0cuL0DO1yNvt59SJ3uDtj00U0foPM,5906
10
10
  nystrom_ncut/nystrom/nystrom_utils.py,sha256=5w-2GAMb7b6ArZdPEnAnKPFFrsbHSfC-S78cvrR6O20,12806
11
11
  nystrom_ncut/transformer/__init__.py,sha256=jjXjcNp3LrxeF6mqG9VY5k3asrqaY6bXzJz6wTpH78Q,105
12
- nystrom_ncut/transformer/axis_align.py,sha256=pX7wk4O6fj-CwRv1TYyPYXsTmmXUtQ5q0c5fDQBVE6Q,3068
13
- nystrom_ncut/transformer/transformer_mixin.py,sha256=fTNtDFYPw2Fc8mjvK2xNHOw5mCkbO0usUpOnnJdyr5M,1743
14
- nystrom_ncut-0.1.9.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
15
- nystrom_ncut-0.1.9.dist-info/METADATA,sha256=8ez3ayc8UcBR8R8Ds7nRAKbrEa3766WNDrQwXToQ9ZM,6058
16
- nystrom_ncut-0.1.9.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
17
- nystrom_ncut-0.1.9.dist-info/top_level.txt,sha256=gM8IWWHYysIRTCvCTcdS4RShOyl9pxpylgSwPUZR2XM,22
18
- nystrom_ncut-0.1.9.dist-info/RECORD,,
12
+ nystrom_ncut/transformer/axis_align.py,sha256=UHPbZVs3XFDvxAQHJC2La8W534k3nJ776DEfQeJVKxg,3297
13
+ nystrom_ncut/transformer/transformer_mixin.py,sha256=YAjrDWTL5Hjnk9J2OsoxvtwT2N0u8IdgMSx0rRFmZzE,1653
14
+ nystrom_ncut-0.2.0.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
15
+ nystrom_ncut-0.2.0.dist-info/METADATA,sha256=rKIYYECCawfOcjVLx6OZJAo5ngi2uKriNLbgKVEs_IY,6058
16
+ nystrom_ncut-0.2.0.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
17
+ nystrom_ncut-0.2.0.dist-info/top_level.txt,sha256=gM8IWWHYysIRTCvCTcdS4RShOyl9pxpylgSwPUZR2XM,22
18
+ nystrom_ncut-0.2.0.dist-info/RECORD,,