nystrom-ncut 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nystrom_ncut/common.py CHANGED
@@ -12,8 +12,8 @@ def ceildiv(a: int, b: int) -> int:
 def lazy_normalize(x: torch.Tensor, n: int = 1000, **normalize_kwargs: Any) -> torch.Tensor:
     numel = np.prod(x.shape[:-1])
     n = min(n, numel)
-    random_indices = torch.randperm(numel)[:n]
-    _x = x.flatten(0, -2)[random_indices]
+    random_indices = torch.randperm(numel, device=x.device)[:n]
+    _x = x.view((-1, x.shape[-1]))[random_indices]
     if torch.allclose(torch.norm(_x, **normalize_kwargs), torch.ones(n, device=x.device)):
         return x
     else:
@@ -21,13 +21,14 @@ def lazy_normalize(x: torch.Tensor, n: int = 1000, **normalize_kwargs: Any) -> torch.Tensor:


 def quantile_min_max(x: torch.Tensor, q1: float, q2: float, n_sample: int = 10000):
-    if x.shape[0] > n_sample:
+    x = x.flatten()
+    if len(x) > n_sample:
         np.random.seed(0)
-        random_idx = np.random.choice(x.shape[0], n_sample, replace=False)
+        random_idx = np.random.choice(len(x), n_sample, replace=False)
         vmin, vmax = x[random_idx].quantile(q1), x[random_idx].quantile(q2)
     else:
         vmin, vmax = x.quantile(q1), x.quantile(q2)
-    return vmin, vmax
+    return vmin.item(), vmax.item()


 def quantile_normalize(x: torch.Tensor, q: float = 0.95):
@@ -57,5 +58,17 @@ def quantile_normalize(x: torch.Tensor, q: float = 0.95):
     return x


+class default_device:
+    def __init__(self, device: torch.device):
+        self._device = device
+
+    def __enter__(self):
+        self._original_device = torch.get_default_device()
+        torch.set_default_device(self._device)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        torch.set_default_device(self._original_device)
+
+
 def profile(name: str, t: torch.Tensor) -> None:
     print(f"{name} --- nan: {t.isnan().any()}, inf: {t.isinf().any()}, max: {t.abs().max()}, min: {t.abs().min()}")
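Note: the new default_device helper temporarily swaps torch's global default device and restores the previous one on exit. A minimal usage sketch, not part of the diff (assumes a torch version providing torch.get_default_device, which the class itself requires):

import torch
from nystrom_ncut.common import default_device

# Tensors created inside the block land on the chosen device;
# the previous default is restored when the block exits.
with default_device(torch.device("cpu")):
    x = torch.zeros(4)  # allocated on the default device set above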

nystrom_ncut/distance_utils.py CHANGED
@@ -1,61 +1,71 @@
-from typing import Literal
+import collections
+from typing import List, Literal, OrderedDict

 import torch

 from .common import lazy_normalize


-DistanceOptions = Literal["cosine", "euclidean", "rbf"]
+DistanceOptions = Literal["cosine", "euclidean"]
+AffinityOptions = Literal["cosine", "rbf", "laplacian"]

+# noinspection PyTypeChecker
+DISTANCE_TO_AFFINITY: OrderedDict[DistanceOptions, List[AffinityOptions]] = collections.OrderedDict([
+    ("cosine", ["cosine"]),
+    ("euclidean", ["rbf", "laplacian"]),
+])
+# noinspection PyTypeChecker
+AFFINITY_TO_DISTANCE: OrderedDict[AffinityOptions, DistanceOptions] = collections.OrderedDict(sum([
+    [(affinity_type, distance_type) for affinity_type in affinity_types]
+    for distance_type, affinity_types in DISTANCE_TO_AFFINITY.items()
+], start=[]))


-def to_euclidean(x: torch.Tensor, disttype: DistanceOptions) -> torch.Tensor:
-    if disttype == "cosine":
+
+
+def to_euclidean(x: torch.Tensor, distance_type: DistanceOptions) -> torch.Tensor:
+    if distance_type == "cosine":
         return lazy_normalize(x, p=2, dim=-1)
-    elif disttype == "rbf":
+    elif distance_type == "euclidean":
         return x
     else:
-        raise ValueError(f"to_euclidean not implemented for disttype {disttype}.")
+        raise ValueError(f"to_euclidean not implemented for distance_type {distance_type}.")


 def distance_from_features(
     features: torch.Tensor,
     features_B: torch.Tensor,
-    distance: DistanceOptions,
+    distance_type: DistanceOptions,
 ):
-    """Compute affinity matrix from input features.
+    """Compute distance matrix from input features.
     Args:
         features (torch.Tensor): input features, shape (n_samples, n_features)
         features_B (torch.Tensor, optional): optional, if not None, compute affinity between two features
-        distance (str): distance metric, 'cosine' (default) or 'euclidean', 'rbf'.
+        distance_type (str): distance metric, 'cosine' (default) or 'euclidean', 'rbf'.
     Returns:
         (torch.Tensor): affinity matrix, shape (n_samples, n_samples)
     """
     # compute distance matrix from input features
-    if distance == "cosine":
-        features = lazy_normalize(features, dim=-1)
-        features_B = lazy_normalize(features_B, dim=-1)
-        D = 1 - features @ features_B.T
-    elif distance == "euclidean":
-        D = torch.cdist(features, features_B, p=2)
-    elif distance == "rbf":
-        D = 0.5 * torch.cdist(features, features_B, p=2) ** 2
-
-        # Outlier-robust scale invariance using quantiles to estimate standard deviation
-        c = 2.0
-        p = torch.erf(torch.tensor((-c, c), device=features.device) * (2 ** -0.5))
-        stds = torch.quantile(features, q=(p + 1) / 2, dim=0)
-        stds = (stds[1] - stds[0]) / (2 * c)
-        D = D / (torch.linalg.norm(stds) ** 2)
-    else:
-        raise ValueError("distance should be 'cosine' or 'euclidean', 'rbf'")
-    return D
+    shape: torch.Size = features.shape[:-2]
+    features = features.view((-1, *features.shape[-2:]))
+    features_B = features_B.view((-1, *features_B.shape[-2:]))
+
+    match distance_type:
+        case "cosine":
+            features = lazy_normalize(features, dim=-1)
+            features_B = lazy_normalize(features_B, dim=-1)
+            D = 1 - features @ features_B.mT
+        case "euclidean":
+            D = torch.cdist(features, features_B, p=2)
+        case _:
+            raise ValueError("Distance should be 'cosine' or 'euclidean'")
+    return D.view((*shape, *D.shape[-2:]))


 def affinity_from_features(
     features: torch.Tensor,
     features_B: torch.Tensor = None,
     affinity_focal_gamma: float = 1.0,
-    distance: DistanceOptions = "cosine",
+    affinity_type: AffinityOptions = "cosine",
 ):
     """Compute affinity matrix from input features.

@@ -64,7 +74,7 @@ def affinity_from_features(
         features_B (torch.Tensor, optional): optional, if not None, compute affinity between two features
         affinity_focal_gamma (float): affinity matrix parameter, lower t reduce the edge weights
             on weak connections, default 1.0
-        distance (str): distance metric, 'cosine' (default) or 'euclidean', 'rbf'.
+        affinity_type (str): distance metric, 'cosine' (default) or 'euclidean'.
     Returns:
         (torch.Tensor): affinity matrix, shape (n_samples, n_samples)
     """
@@ -75,9 +85,21 @@ def affinity_from_features(
     features_B = features if features_B is None else features_B

     # compute distance matrix from input features
-    D = distance_from_features(features, features_B, distance)
+    D = distance_from_features(features, features_B, AFFINITY_TO_DISTANCE[affinity_type])

-    # torch.exp make affinity matrix positive definite,
     # lower affinity_focal_gamma reduce the weak edge weights
-    A = torch.exp(-D / affinity_focal_gamma)
+    match affinity_type:
+        case "cosine" | "laplacian":
+            A = torch.exp(-D / affinity_focal_gamma)                                    # [... x n x n]
+        case "rbf":
+            # Outlier-robust scale invariance using quantiles to estimate standard deviation
+            c = 2.0
+            p = torch.erf(torch.tensor((-c, c), device=features.device) * (2 ** -0.5))
+            stds = torch.nanquantile(features, q=(p + 1) / 2, dim=-2)                   # [2 x ... x d]
+            stds = (stds[1] - stds[0]) / (2 * c)                                        # [... x d]
+            D = 0.5 * (D / torch.norm(stds, dim=-1)[..., None, None]) ** 2
+            A = torch.exp(-D / affinity_focal_gamma)
+        case _:
+            raise ValueError("Affinity should be 'cosine', 'rbf', or 'laplacian'")
+
     return A
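Note: 0.3.0 splits the old three-valued distance option into DistanceOptions (metrics) and AffinityOptions (kernels), with AFFINITY_TO_DISTANCE resolving a kernel to the metric it is computed from. An illustrative sketch of the new API, not taken from the diff:

import torch
from nystrom_ncut.distance_utils import AFFINITY_TO_DISTANCE, affinity_from_features

# {'cosine': 'cosine', 'rbf': 'euclidean', 'laplacian': 'euclidean'}
print(dict(AFFINITY_TO_DISTANCE))

features = torch.randn(8, 16)
A = affinity_from_features(features, affinity_type="laplacian")  # [8 x 8] affinity matrix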

nystrom_ncut/nystrom/__init__.py CHANGED
@@ -1,6 +1,3 @@
-from .distance_realization import (
-    DistanceRealization,
-)
 from .normalized_cut import (
     NCut,
 )

nystrom_ncut/nystrom/distance_realization.py CHANGED
@@ -18,10 +18,10 @@ from ..sampling_utils import (
 class GramKernel(OnlineKernel):
     def __init__(
         self,
-        distance: DistanceOptions,
+        distance_type: DistanceOptions,
         eig_solver: EigSolverOptions,
     ):
-        self.distance: DistanceOptions = distance
+        self.distance_type: DistanceOptions = distance_type
         self.eig_solver: EigSolverOptions = eig_solver

         # Anchor matrices
@@ -40,7 +40,7 @@ class GramKernel(OnlineKernel):
         self.A = -0.5 * distance_from_features(
             self.anchor_features,                           # [n x d]
             self.anchor_features,
-            distance=self.distance,
+            distance_type=self.distance_type,
         )                                                   # [n x n]
         d = features.shape[-1]
         U, L = solve_eig(
@@ -58,7 +58,7 @@ class GramKernel(OnlineKernel):
         B = -0.5 * distance_from_features(
             self.anchor_features,                           # [n x d]
             features,                                       # [m x d]
-            distance=self.distance,
+            distance_type=self.distance_type,
         )                                                   # [n x m]
         b_r = torch.sum(B, dim=-1)                          # [n]
         b_c = torch.sum(B, dim=-2)                          # [m]
@@ -84,7 +84,7 @@ class GramKernel(OnlineKernel):
         B = -0.5 * distance_from_features(
             self.anchor_features,
             features,
-            distance=self.distance,
+            distance_type=self.distance_type,
         )
         b_c = torch.sum(B, dim=-2)                          # [m]
         col_sum = b_c + B.mT @ self.Ainv @ self.b_r         # [m]
@@ -98,7 +98,7 @@ class DistanceRealization(OnlineNystromSubsampleFit):
     def __init__(
         self,
         n_components: int = 100,
-        distance: DistanceOptions = "cosine",
+        distance_type: DistanceOptions = "cosine",
         sample_config: SampleConfig = SampleConfig(),
         eig_solver: EigSolverOptions = "svd_lowrank",
         chunk_size: int = 8192,
@@ -115,13 +115,12 @@ class DistanceRealization(OnlineNystromSubsampleFit):
         OnlineNystromSubsampleFit.__init__(
             self,
             n_components=n_components,
-            kernel=GramKernel(distance, eig_solver),
-            distance=distance,
+            kernel=GramKernel(distance_type, eig_solver),
+            distance_type=distance_type,
             sample_config=sample_config,
            eig_solver=eig_solver,
             chunk_size=chunk_size,
         )
-        self.distance: DistanceOptions = distance

     def fit_transform(
         self,
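Note: DistanceRealization keeps its behavior but renames its keyword from distance= to distance_type=, and it is no longer re-exported from nystrom_ncut.nystrom (see the __init__.py hunk above). A hedged construction sketch, not from the diff:

from nystrom_ncut.nystrom.distance_realization import DistanceRealization

# distance= (0.2.2) becomes distance_type= (0.3.0); other defaults are unchanged.
dr = DistanceRealization(n_components=10, distance_type="cosine")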

nystrom_ncut/nystrom/normalized_cut.py CHANGED
@@ -8,7 +8,8 @@ from .nystrom_utils import (
     solve_eig,
 )
 from ..distance_utils import (
-    DistanceOptions,
+    AffinityOptions,
+    AFFINITY_TO_DISTANCE,
     affinity_from_features,
 )
 from ..sampling_utils import (
@@ -20,80 +21,83 @@ class LaplacianKernel(OnlineKernel):
     def __init__(
         self,
         affinity_focal_gamma: float,
-        distance: DistanceOptions,
+        affinity_type: AffinityOptions,
         adaptive_scaling: bool,
         eig_solver: EigSolverOptions,
     ):
         self.affinity_focal_gamma = affinity_focal_gamma
-        self.distance: DistanceOptions = distance
+        self.affinity_type: AffinityOptions = affinity_type
         self.adaptive_scaling: bool = adaptive_scaling
         self.eig_solver: EigSolverOptions = eig_solver

         # Anchor matrices
-        self.anchor_features: torch.Tensor = None           # [n x d]
-        self.A: torch.Tensor = None                         # [n x n]
-        self.Ainv: torch.Tensor = None                      # [n x n]
+        self.anchor_features: torch.Tensor = None           # [... x n x d]
+        self.anchor_mask: torch.Tensor = None
+        self.A: torch.Tensor = None                         # [... x n x n]
+        self.Ainv: torch.Tensor = None                      # [... x n x n]

         # Updated matrices
-        self.a_r: torch.Tensor = None                       # [n]
-        self.b_r: torch.Tensor = None                       # [n]
+        self.a_r: torch.Tensor = None                       # [... x n]
+        self.b_r: torch.Tensor = None                       # [... x n]

     def fit(self, features: torch.Tensor) -> None:
-        self.anchor_features = features                     # [n x d]
-        self.A = affinity_from_features(
-            self.anchor_features,                           # [n x d]
+        self.anchor_features = features                     # [... x n x d]
+        self.anchor_mask = torch.all(torch.isnan(self.anchor_features), dim=-1)     # [... x n]
+
+        self.A = torch.nan_to_num(affinity_from_features(
+            self.anchor_features,                           # [... x n x d]
             affinity_focal_gamma=self.affinity_focal_gamma,
-            distance=self.distance,
-        )                                                   # [n x n]
+            affinity_type=self.affinity_type,
+        ), nan=0.0)                                         # [... x n x n]
         d = features.shape[-1]
         U, L = solve_eig(
             self.A,
             num_eig=d + 1,  # d * (d + 3) // 2 + 1,
             eig_solver=self.eig_solver,
-        )                                                   # [n x (d + 1)], [d + 1]
-        self.Ainv = U @ torch.diag(1 / L) @ U.mT            # [n x n]
-        self.a_r = torch.sum(self.A, dim=-1)                # [n]
-        self.b_r = torch.zeros_like(self.a_r)               # [n]
+        )                                                   # [... x n x (d + 1)], [... x (d + 1)]
+        self.Ainv = U @ torch.diag_embed(1 / L) @ U.mT      # [... x n x n]
+        self.a_r = torch.where(self.anchor_mask, torch.inf, torch.sum(self.A, dim=-1))  # [... x n]
+        self.b_r = torch.zeros_like(self.a_r)               # [... x n]

     def _affinity(self, features: torch.Tensor) -> torch.Tensor:
-        B = affinity_from_features(
-            self.anchor_features,                           # [n x d]
-            features,                                       # [m x d]
+        B = torch.where(self.anchor_mask[..., None], 0.0, affinity_from_features(
+            self.anchor_features,                           # [... x n x d]
+            features,                                       # [... x m x d]
             affinity_focal_gamma=self.affinity_focal_gamma,
-            distance=self.distance,
-        )                                                   # [n x m]
+            affinity_type=self.affinity_type,
+        ))                                                  # [... x n x m]
         if self.adaptive_scaling:
             diagonal = (
-                einops.rearrange(B, "n m -> m 1 n")         # [m x 1 x n]
-                @ self.Ainv                                 # [n x n]
-                @ einops.rearrange(B, "n m -> m n 1")       # [m x n x 1]
-            ).squeeze(1, 2)                                 # [m]
-            adaptive_scale = diagonal ** -0.5               # [m]
-            B = B * adaptive_scale
-        return B                                            # [n x m]
+                einops.rearrange(B, "... n m -> ... m 1 n")     # [... x m x 1 x n]
+                @ self.Ainv                                     # [... x n x n]
+                @ einops.rearrange(B, "... n m -> ... m n 1")   # [... x m x n x 1]
+            ).squeeze(-2, -1)                               # [... x m]
+            adaptive_scale = diagonal ** -0.5               # [... x m]
+            B = B * adaptive_scale[..., None, :]
+        return B                                            # [... x n x m]

     def update(self, features: torch.Tensor) -> torch.Tensor:
-        B = self._affinity(features)                        # [n x m]
-        b_r = torch.sum(B, dim=-1)                          # [n]
-        b_c = torch.sum(B, dim=-2)                          # [m]
-        self.b_r = self.b_r + b_r                           # [n]
+        B = self._affinity(features)                        # [... x n x m]
+        b_r = torch.sum(torch.nan_to_num(B, nan=0.0), dim=-1)   # [... x n]
+        b_c = torch.sum(B, dim=-2)                          # [... x m]
+        self.b_r = self.b_r + b_r                           # [... x n]

-        row_sum = self.a_r + self.b_r                       # [n]
-        col_sum = b_c + B.mT @ self.Ainv @ self.b_r         # [m]
-        scale = (row_sum[:, None] * col_sum) ** -0.5        # [n x m]
-        return (B * scale).mT                               # [m x n]
+        row_sum = self.a_r + self.b_r                       # [... x n]
+        col_sum = b_c + (B.mT @ (self.Ainv @ self.b_r[..., None]))[..., 0]      # [... x m]
+        scale = (row_sum[..., :, None] * col_sum[..., None, :]) ** -0.5         # [... x n x m]
+        return (B * scale).mT                               # [... x m x n]

     def transform(self, features: torch.Tensor = None) -> torch.Tensor:
-        row_sum = self.a_r + self.b_r                       # [n]
+        row_sum = self.a_r + self.b_r                       # [... x n]
         if features is None:
-            B = self.A                                      # [n x n]
-            col_sum = row_sum                               # [n]
+            B = self.A                                      # [... x n x n]
+            col_sum = row_sum                               # [... x n]
         else:
             B = self._affinity(features)
-            b_c = torch.sum(B, dim=-2)                      # [m]
-            col_sum = b_c + B.mT @ self.Ainv @ self.b_r     # [m]
-        scale = (row_sum[:, None] * col_sum) ** -0.5        # [n x m]
-        return (B * scale).mT                               # [m x n]
+            b_c = torch.sum(B, dim=-2)                      # [... x m]
+            col_sum = b_c + (B.mT @ (self.Ainv @ self.b_r[..., None]))[..., 0]  # [... x m]
+        scale = (row_sum[..., :, None] * col_sum[..., None, :]) ** -0.5         # [... x n x m]
+        return (B * scale).mT                               # [... x m x n]


 class NCut(OnlineNystromSubsampleFit):
@@ -103,7 +107,7 @@ class NCut(OnlineNystromSubsampleFit):
         self,
         n_components: int = 100,
         affinity_focal_gamma: float = 1.0,
-        distance: DistanceOptions = "cosine",
+        affinity_type: AffinityOptions = "cosine",
         adaptive_scaling: bool = False,
         sample_config: SampleConfig = SampleConfig(),
         eig_solver: EigSolverOptions = "svd_lowrank",
@@ -124,8 +128,8 @@ class NCut(OnlineNystromSubsampleFit):
         OnlineNystromSubsampleFit.__init__(
             self,
             n_components=n_components,
-            kernel=LaplacianKernel(affinity_focal_gamma, distance, adaptive_scaling, eig_solver),
-            distance=distance,
+            kernel=LaplacianKernel(affinity_focal_gamma, affinity_type, adaptive_scaling, eig_solver),
+            distance_type=AFFINITY_TO_DISTANCE[affinity_type],
             sample_config=sample_config,
             eig_solver=eig_solver,
             chunk_size=chunk_size,
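Note: for NCut the public change is the same rename, distance= becomes affinity_type= (which now also accepts "laplacian"), and the base class receives the derived metric via AFFINITY_TO_DISTANCE. A minimal sketch, assuming fit_transform returns the eigenvector embedding as in TorchTransformerMixin:

import torch
from nystrom_ncut import NCut
from nystrom_ncut.sampling_utils import SampleConfig

nc = NCut(
    n_components=20,
    affinity_type="rbf",  # was distance="rbf" in 0.2.2
    sample_config=SampleConfig(method="fps", num_sample=256),
)
X = torch.randn(2048, 64)
V = nc.fit_transform(X)  # Nystrom-approximated NCut embedding, [2048 x 20]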

nystrom_ncut/nystrom/nystrom_utils.py CHANGED
@@ -25,15 +25,15 @@ EigSolverOptions = Literal["svd_lowrank", "lobpcg", "svd", "eigh"]

 class OnlineKernel:
     @abstractmethod
-    def fit(self, features: torch.Tensor) -> "OnlineKernel":                # [n x d]
+    def fit(self, features: torch.Tensor) -> "OnlineKernel":                # [... x n x d]
         """"""

     @abstractmethod
-    def update(self, features: torch.Tensor) -> torch.Tensor:               # [m x d] -> [m x n]
+    def update(self, features: torch.Tensor) -> torch.Tensor:               # [... x m x d] -> [... x m x n]
         """"""

     @abstractmethod
-    def transform(self, features: torch.Tensor = None) -> torch.Tensor:     # [m x d] -> [m x n]
+    def transform(self, features: torch.Tensor = None) -> torch.Tensor:     # [... x m x d] -> [... x m x n]
         """"""


@@ -54,20 +54,21 @@ class OnlineNystrom(TorchTransformerMixin):
         self.n_components: int = n_components
         self.kernel: OnlineKernel = kernel
         self.eig_solver: EigSolverOptions = eig_solver
+        self.shape: torch.Size = None                                       # ...

         self.chunk_size = chunk_size

         # Anchor matrices
-        self.anchor_features: torch.Tensor = None                           # [n x d]
-        self.A: torch.Tensor = None                                         # [n x n]
-        self.Ahinv: torch.Tensor = None                                     # [n x n]
-        self.Ahinv_UL: torch.Tensor = None                                  # [n x indirect_pca_dim]
-        self.Ahinv_VT: torch.Tensor = None                                  # [indirect_pca_dim x n]
+        self.anchor_features: torch.Tensor = None                           # [... x n x d]
+        self.A: torch.Tensor = None                                         # [... x n x n]
+        self.Ahinv: torch.Tensor = None                                     # [... x n x n]
+        self.Ahinv_UL: torch.Tensor = None                                  # [... x n x indirect_pca_dim]
+        self.Ahinv_VT: torch.Tensor = None                                  # [... x indirect_pca_dim x n]

         # Updated matrices
-        self.S: torch.Tensor = None                                         # [n x n]
-        self.transform_matrix: torch.Tensor = None                          # [n x n_components]
-        self.eigenvalues_: torch.Tensor = None                              # [n]
+        self.S: torch.Tensor = None                                         # [... x n x n]
+        self.transform_matrix: torch.Tensor = None                          # [... x n x n_components]
+        self.eigenvalues_: torch.Tensor = None                              # [... x n]

     def _update_to_kernel(self, d: int) -> Tuple[torch.Tensor, torch.Tensor]:
         self.A = self.S = self.kernel.transform()
@@ -75,10 +76,10 @@ class OnlineNystrom(TorchTransformerMixin):
             self.A,
             num_eig=d + 1,  # d * (d + 3) // 2 + 1,
             eig_solver=self.eig_solver,
-        )                                                                   # [n x (? + 1)], [? + 1]
-        self.Ahinv_UL = U * (L ** -0.5)                                     # [n x (? + 1)]
-        self.Ahinv_VT = U.mT                                                # [(? + 1) x n]
-        self.Ahinv = self.Ahinv_UL @ self.Ahinv_VT                          # [n x n]
+        )                                                                   # [... x n x (? + 1)], [... x (? + 1)]
+        self.Ahinv_UL = U * (L[..., None, :] ** -0.5)                       # [... x n x (? + 1)]
+        self.Ahinv_VT = U.mT                                                # [... x (? + 1) x n]
+        self.Ahinv = self.Ahinv_UL @ self.Ahinv_VT                          # [... x n x n]
         return U, L

     def fit(self, features: torch.Tensor) -> "OnlineNystrom":
@@ -89,65 +90,63 @@ class OnlineNystrom(TorchTransformerMixin):
         self.anchor_features = features

         self.kernel.fit(self.anchor_features)
-        U, L = self._update_to_kernel(features.shape[-1])                   # [n x (d + 1)], [d + 1]
+        U, L = self._update_to_kernel(features.shape[-1])                   # [... x n x (d + 1)], [... x (d + 1)]

-        self.transform_matrix = (U / L)[:, :self.n_components]             # [n x n_components]
-        self.eigenvalues_ = L[:self.n_components]                          # [n_components]
-        self.is_fitted = True
-        return U[:, :self.n_components]                                    # [n x n_components]
+        self.transform_matrix = (U / L[..., None, :])[..., :, :self.n_components]   # [... x n x n_components]
+        self.eigenvalues_ = L[..., :self.n_components]                      # [... x n_components]
+        return U[..., :, :self.n_components]                                # [... x n x n_components]

     def update(self, features: torch.Tensor) -> torch.Tensor:
         d = features.shape[-1]
-        n_chunks = ceildiv(len(features), self.chunk_size)
+        n_chunks = ceildiv(features.shape[-2], self.chunk_size)
         if n_chunks > 1:
             """ Chunked version """
-            chunks = torch.chunk(features, n_chunks, dim=0)
+            chunks = torch.chunk(features, n_chunks, dim=-2)
             for chunk in chunks:
                 self.kernel.update(chunk)
             self._update_to_kernel(d)

-            compressed_BBT = 0.0                                            # [(? + 1) x (? + 1))]
+            compressed_BBT = 0.0                                            # [... x (? + 1) x (? + 1))]
             for chunk in chunks:
-                _B = self.kernel.transform(chunk).mT                        # [n x _m]
-                _compressed_B = self.Ahinv_VT @ _B                          # [(? + 1) x _m]
-                compressed_BBT = compressed_BBT + _compressed_B @ _compressed_B.mT      # [(? + 1) x (? + 1)]
-            self.S = self.S + self.Ahinv_UL @ compressed_BBT @ self.Ahinv_UL.mT         # [n x n]
-            US, self.eigenvalues_ = solve_eig(self.S, self.n_components, self.eig_solver)   # [n x n_components], [n_components]
-            self.transform_matrix = self.Ahinv @ US * (self.eigenvalues_ ** -0.5)       # [n x n_components]
+                _B = self.kernel.transform(chunk).mT                        # [... x n x _m]
+                _compressed_B = self.Ahinv_VT @ _B                          # [... x (? + 1) x _m]
+                _compressed_B = torch.nan_to_num(_compressed_B, nan=0.0)
+                compressed_BBT = compressed_BBT + _compressed_B @ _compressed_B.mT      # [... x (? + 1) x (? + 1)]
+            self.S = self.S + self.Ahinv_UL @ compressed_BBT @ self.Ahinv_UL.mT         # [... x n x n]
+            US, self.eigenvalues_ = solve_eig(self.S, self.n_components, self.eig_solver)   # [... x n x n_components], [... x n_components]
+            self.transform_matrix = self.Ahinv @ US * (self.eigenvalues_[..., None, :] ** -0.5)     # [... x n x n_components]

             VS = []
             for chunk in chunks:
-                VS.append(self.kernel.transform(chunk) @ self.transform_matrix)         # [_m x n_components]
-            VS = torch.cat(VS, dim=0)
-            return VS                                                       # [m x n_components]
+                VS.append(self.kernel.transform(chunk) @ self.transform_matrix)         # [... x _m x n_components]
+            VS = torch.cat(VS, dim=-2)
+            return VS                                                       # [... x m x n_components]
         else:
             """ Unchunked version """
-            B = self.kernel.update(features).mT                             # [n x m]
+            B = self.kernel.update(features).mT                             # [... x n x m]
             self._update_to_kernel(d)
-            compressed_B = self.Ahinv_VT @ B                                # [indirect_pca_dim x m]
+            compressed_B = self.Ahinv_VT @ B                                # [... x (? + 1) x m]
+            compressed_B = torch.nan_to_num(compressed_B, nan=0.0)

-            self.S = self.S + self.Ahinv_UL @ (compressed_B @ compressed_B.mT) @ self.Ahinv_UL.mT   # [n x n]
-            US, self.eigenvalues_ = solve_eig(self.S, self.n_components, self.eig_solver)   # [n x n_components], [n_components]
-            self.transform_matrix = self.Ahinv @ US * (self.eigenvalues_ ** -0.5)       # [n x n_components]
+            self.S = self.S + self.Ahinv_UL @ (compressed_B @ compressed_B.mT) @ self.Ahinv_UL.mT   # [... x n x n]
+            US, self.eigenvalues_ = solve_eig(self.S, self.n_components, self.eig_solver)   # [... x n x n_components], [... x n_components]
+            self.transform_matrix = self.Ahinv @ US * (self.eigenvalues_[..., None, :] ** -0.5)     # [... x n x n_components]

-            return B.mT @ self.transform_matrix                             # [m x n_components]
+            return B.mT @ self.transform_matrix                             # [... x m x n_components]

-    def transform(self, features: torch.Tensor = None) -> torch.Tensor:
-        if features is None:
-            VS = self.A @ self.transform_matrix                             # [n x n_components]
+    def transform(self, features: torch.Tensor) -> torch.Tensor:
+        n_chunks = ceildiv(features.shape[-2], self.chunk_size)
+        if n_chunks > 1:
+            """ Chunked version """
+            chunks = torch.chunk(features, n_chunks, dim=-2)
+            VS = []
+            for chunk in chunks:
+                VS.append(self.kernel.transform(chunk) @ self.transform_matrix)         # [... x _m x n_components]
+            VS = torch.cat(VS, dim=-2)
         else:
-            n_chunks = ceildiv(len(features), self.chunk_size)
-            if n_chunks > 1:
-                """ Chunked version """
-                chunks = torch.chunk(features, n_chunks, dim=0)
-                VS = []
-                for chunk in chunks:
-                    VS.append(self.kernel.transform(chunk) @ self.transform_matrix)     # [_m x n_components]
-                VS = torch.cat(VS, dim=0)
-            else:
-                """ Unchunked version """
-                VS = self.kernel.transform(features) @ self.transform_matrix            # [m x n_components]
-            return VS                                                       # [m x n_components]
+            """ Unchunked version """
+            VS = self.kernel.transform(features) @ self.transform_matrix                # [... x m x n_components]
+        return VS                                                           # [... x m x n_components]


 class OnlineNystromSubsampleFit(OnlineNystrom):
@@ -155,7 +154,7 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
         self,
         n_components: int,
         kernel: OnlineKernel,
-        distance: DistanceOptions,
+        distance_type: DistanceOptions,
         sample_config: SampleConfig,
         eig_solver: EigSolverOptions = "svd_lowrank",
         chunk_size: int = 8192,
@@ -167,7 +166,7 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
             eig_solver=eig_solver,
             chunk_size=chunk_size,
         )
-        self.distance: DistanceOptions = distance
+        self.distance_type: DistanceOptions = distance_type
         self.sample_config: SampleConfig = sample_config
         self.sample_config._ncut_obj = copy.deepcopy(self)
         self.anchor_indices: torch.Tensor = None
@@ -177,7 +176,7 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
         features: torch.Tensor,
         precomputed_sampled_indices: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        _n = features.shape[0]
+        _n = features.shape[-2]
         if self.sample_config.num_sample >= _n:
             logging.info(
                 f"NCUT nystrom num_sample is larger than number of input samples, nyström approximation is not needed, setting num_sample={_n}"
@@ -189,16 +188,17 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
         else:
             self.anchor_indices = subsample_features(
                 features=features,
-                disttype=self.distance,
+                distance_type=self.distance_type,
                 config=self.sample_config,
             )
-        sampled_features = features[self.anchor_indices]
+        sampled_features = torch.gather(features, -2, self.anchor_indices[..., None].expand([-1] * self.anchor_indices.ndim + [features.shape[-1]]))
         OnlineNystrom.fit(self, sampled_features)

-        _n_not_sampled = _n - len(sampled_features)
+        _n_not_sampled = _n - self.anchor_indices.shape[-1]
         if _n_not_sampled > 0:
-            unsampled_indices = torch.full((_n,), True, device=features.device).scatter_(0, self.anchor_indices, False)
-            unsampled_features = features[unsampled_indices]
+            unsampled_mask = torch.full(features.shape[:-1], True, device=features.device).scatter_(-1, self.anchor_indices, False)
+            unsampled_indices = torch.where(unsampled_mask)[-1].view((*features.shape[:-2], -1))
+            unsampled_features = torch.gather(features, -2, unsampled_indices[..., None].expand([-1] * unsampled_indices.ndim + [features.shape[-1]]))
             V_unsampled = OnlineNystrom.update(self, unsampled_features)
         else:
             unsampled_indices = V_unsampled = None
@@ -236,12 +236,12 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
             (torch.Tensor): eigen_values, sorted in descending order, shape (num_eig,)
         """
         unsampled_indices, V_unsampled = OnlineNystromSubsampleFit._fit_helper(self, features, precomputed_sampled_indices)
-        V_sampled = OnlineNystrom.transform(self)
+        V_sampled = OnlineNystrom.transform(self, self.anchor_features)

         if unsampled_indices is not None:
-            V = torch.zeros((len(unsampled_indices), self.n_components), device=features.device)
-            V[~unsampled_indices] = V_sampled
-            V[unsampled_indices] = V_unsampled
+            V = torch.zeros((*features.shape[:-1], self.n_components), device=features.device)
+            for (indices, _V) in [(self.anchor_indices, V_sampled), (unsampled_indices, V_unsampled)]:
+                V.scatter_(-2, indices[..., None].expand([-1] * indices.ndim + [self.n_components]), _V)
         else:
             V = V_sampled
         return V
@@ -264,12 +264,16 @@ def solve_eig(
         (torch.Tensor): eigenvectors corresponding to the eigenvalues, shape (n_samples, num_eig)
         (torch.Tensor): eigenvalues of the eigenvectors, sorted in descending order
     """
-    A = A + eig_value_buffer * torch.eye(A.shape[0], device=A.device)
+    shape: torch.Size = A.shape[:-2]
+    A = A.view((-1, *A.shape[-2:]))
+    bsz: int = A.shape[0]
+
+    A = A + eig_value_buffer * torch.eye(A.shape[-1], device=A.device)

     # compute eigenvectors
     if eig_solver == "svd_lowrank":  # default
         # only top q eigenvectors, fastest
-        eigen_vector, eigen_value, _ = torch.svd_lowrank(A, q=num_eig)
+        eigen_vector, eigen_value, _ = torch.svd_lowrank(A, q=num_eig)      # complex: [(...) x N x D], [(...) x D]
     elif eig_solver == "lobpcg":
         # only top k eigenvectors, fast
         eigen_value, eigen_vector = torch.lobpcg(A, k=num_eig)
@@ -286,11 +290,15 @@ def solve_eig(
     eigen_value = eigen_value - eig_value_buffer

     # sort eigenvectors by eigenvalues, take top (descending order)
-    indices = torch.topk(eigen_value.abs(), k=num_eig, dim=0).indices
-    eigen_value, eigen_vector = eigen_value[indices], eigen_vector[:, indices]
+    indices = torch.topk(eigen_value.abs(), k=num_eig, dim=-1).indices      # int: [(...) x S]
+    eigen_value = eigen_value[torch.arange(bsz)[:, None], indices]          # complex: [(...) x S]
+    eigen_vector = eigen_vector[torch.arange(bsz)[:, None], :, indices].mT  # complex: [(...) x N x S]

     # correct the random rotation (flipping sign) of eigenvectors
-    sign = torch.sum(eigen_vector.real, dim=0).sign()
+    sign = torch.sign(torch.sum(eigen_vector.real, dim=-2, keepdim=True))   # float: [(...) x 1 x S]
     sign[sign == 0] = 1.0
     eigen_vector = eigen_vector * sign
+
+    eigen_value = eigen_value.view((*shape, *eigen_value.shape[-1:]))       # complex: [... x S]
+    eigen_vector = eigen_vector.view((*shape, *eigen_vector.shape[-2:]))    # complex: [... x N x S]
     return eigen_vector, eigen_value
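Note: most of the nystrom_utils changes generalize the shape contract from [n x ...] to [... x n x ...], so every kernel and solver now accepts arbitrary leading batch dimensions. A small sketch of the batched solve_eig contract, with shapes inferred from the comments in the hunk above:

import torch
from nystrom_ncut.nystrom.nystrom_utils import solve_eig

M = torch.randn(3, 10, 10)
A = M @ M.mT  # a batch of three symmetric PSD matrices, [3 x 10 x 10]

# Leading dims are flattened, solved per batch, then restored on the outputs.
eigenvectors, eigenvalues = solve_eig(A, num_eig=4, eig_solver="svd_lowrank")
print(eigenvectors.shape, eigenvalues.shape)  # [3, 10, 4] and [3, 4]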

nystrom_ncut/sampling_utils.py CHANGED
@@ -1,17 +1,22 @@
-import logging
 from dataclasses import dataclass
 from typing import Literal

 import torch
 from pytorch3d.ops import sample_farthest_points

+from .common import (
+    default_device,
+)
 from .distance_utils import (
     DistanceOptions,
     to_euclidean,
 )
+from .transformer import (
+    TorchTransformerMixin,
+)


-SampleOptions = Literal["random", "fps", "fps_recursive"]
+SampleOptions = Literal["full", "random", "fps", "fps_recursive"]


 @dataclass
@@ -20,69 +25,77 @@ class SampleConfig:
     num_sample: int = 10000
     fps_dim: int = 12
     n_iter: int = None
-    _ncut_obj: object = None
+    _ncut_obj: TorchTransformerMixin = None


 @torch.no_grad()
 def subsample_features(
     features: torch.Tensor,
-    disttype: DistanceOptions,
+    distance_type: DistanceOptions,
     config: SampleConfig,
-    max_draw: int = 1000000,
 ):
-    features = features.detach()
-    if config.num_sample >= features.shape[0]:
-        # if too many samples, use all samples and bypass Nystrom-like approximation
-        logging.info(
-            "num_sample is larger than total, bypass Nystrom-like approximation"
-        )
-        sampled_indices = torch.arange(features.shape[0])
-    else:
-        # sample subgraph
-        if config.method == "fps":  # default
-            features = to_euclidean(features, disttype)
-            if config.num_sample > max_draw:
-                logging.warning(
-                    f"num_sample is larger than max_draw, apply farthest point sampling on random sampled {max_draw} samples"
-                )
-                draw_indices = torch.randperm(features.shape[0])[:max_draw]
-                sampled_indices = fpsample(features[draw_indices], config)
-                sampled_indices = draw_indices[sampled_indices]
-            else:
-                sampled_indices = fpsample(features, config)
-
-        elif config.method == "random":  # not recommended
-            sampled_indices = torch.randperm(features.shape[0])[:config.num_sample]
-
-        elif config.method == "fps_recursive":
-            features = to_euclidean(features, disttype)
-            sampled_indices = subsample_features(
-                features=features,
-                disttype=disttype,
-                config=SampleConfig(method="fps", num_sample=config.num_sample, fps_dim=config.fps_dim)
-            )
-            nc = config._ncut_obj
-            for _ in range(config.n_iter):
-                fps_features, eigenvalues = nc.fit_transform(features, precomputed_sampled_indices=sampled_indices)
-
-                fps_features = to_euclidean(fps_features[:, :config.fps_dim], "cosine")
-                sampled_indices = torch.sort(fpsample(fps_features, config)).values
+    features = features.detach()                                                        # float: [... x n x d]
+    with default_device(features.device):
+        if config.method == "full" or config.num_sample >= features.shape[0]:
+            sampled_indices = torch.arange(features.shape[-2]).expand(features.shape[:-1])  # int: [... x n]
         else:
-            raise ValueError("sample_method should be 'farthest' or 'random'")
-        sampled_indices = torch.sort(sampled_indices).values
-    return sampled_indices.to(features.device)
+            # sample
+            match config.method:
+                case "fps":  # default
+                    sampled_indices = fpsample(to_euclidean(features, distance_type), config)
+
+                case "random":  # not recommended
+                    mask = torch.all(torch.isfinite(features), dim=-1)                  # bool: [... x n]
+                    weights = mask.to(torch.float) + torch.rand(mask.shape)             # float: [... x n]
+                    sampled_indices = torch.topk(weights, k=config.num_sample, dim=-1).indices  # int: [... x num_sample]
+
+                case "fps_recursive":
+                    features = to_euclidean(features, distance_type)                    # float: [... x n x d]
+                    sampled_indices = subsample_features(
+                        features=features,
+                        distance_type=distance_type,
+                        config=SampleConfig(method="fps", num_sample=config.num_sample, fps_dim=config.fps_dim)
+                    )                                                                   # int: [... x num_sample]
+                    nc = config._ncut_obj
+                    for _ in range(config.n_iter):
+                        fps_features, eigenvalues = nc.fit_transform(features, precomputed_sampled_indices=sampled_indices)

+                        fps_features = to_euclidean(fps_features[:, :config.fps_dim], "cosine")
+                        sampled_indices = torch.sort(fpsample(fps_features, config), dim=-1).values
+
+                case _:
+                    raise ValueError("sample_method should be 'farthest' or 'random'")
+            sampled_indices = torch.sort(sampled_indices, dim=-1).values
+        return sampled_indices


 def fpsample(
     features: torch.Tensor,
     config: SampleConfig,
 ):
-    # PCA to reduce the dimension
-    if features.shape[1] > config.fps_dim:
-        U, S, V = torch.pca_lowrank(features, q=config.fps_dim)
-        features = U * S
+    shape = features.shape[:-2]                                                         # ...
+    features = features.view((-1, *features.shape[-2:]))                                # [(...) x n x d]
+    bsz = features.shape[0]
+
+    mask = torch.all(torch.isfinite(features), dim=-1)                                  # bool: [(...) x n]
+    count = torch.sum(mask, dim=-1)                                                     # int: [(...)]
+    order = torch.topk(mask.to(torch.int), k=torch.max(count).item(), dim=-1).indices   # int: [(...) x max_count]
+
+    features = torch.nan_to_num(features[torch.arange(bsz)[:, None], order], nan=0.0)   # float: [(...) x max_count x d]
+    if features.shape[-1] > config.fps_dim:
+        U, S, V = torch.pca_lowrank(features, q=config.fps_dim)                         # float: [(...) x max_count x fps_dim], [(...) x fps_dim], [(...) x fps_dim x fps_dim]
+        features = U * S[..., None, :]                                                  # float: [(...) x max_count x fps_dim]

     try:
-        return sample_farthest_points(features[None], K=config.num_sample)[1][0]
+        sample_indices = sample_farthest_points(
+            features, lengths=count, K=config.num_sample
+        )[1]                                                                            # int: [(...) x num_sample]
     except RuntimeError:
-        return sample_farthest_points(features[None].cpu(), K=config.num_sample)[1][0].to(features.device)
+        original_device = features.device
+        alternative_device = "cuda" if original_device == "cpu" else "cpu"
+        sample_indices = sample_farthest_points(
+            features.to(alternative_device), lengths=count.to(alternative_device), K=config.num_sample
+        )[1].to(original_device)                                                        # int: [(...) x num_sample]
    sample_indices = torch.gather(order, 1, sample_indices)                             # int: [(...) x num_sample]

    return sample_indices.view((*shape, *sample_indices.shape[-1:]))                    # int: [... x num_sample]
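Note: sampling_utils drops the max_draw cap and the logging fallbacks in favor of an explicit "full" method, and fpsample now masks non-finite rows and batches over leading dimensions. A short configuration sketch, not from the diff:

from nystrom_ncut.sampling_utils import SampleConfig

# New in 0.3.0: "full" keeps every sample, replacing the old
# num_sample >= n shortcut that logged and returned torch.arange(n).
full_config = SampleConfig(method="full")
fps_config = SampleConfig(method="fps", num_sample=10000, fps_dim=12)  # the default path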

nystrom_ncut/transformer/axis_align.py CHANGED
@@ -70,7 +70,6 @@ class AxisAlign(TorchTransformerMixin):
             raise ValueError(f"Invalid sort method {self.sort_method}.")

         self.R = self.R[torch.argsort(sort_metric, dim=0, descending=True)]
-        self.is_fitted = True
         return self

     def transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:

nystrom_ncut/transformer/transformer_mixin.py CHANGED
@@ -36,8 +36,6 @@ class TorchTransformerMixin:
     >>> transformer.fit_transform(X)
     array([1, 1, 1])
     """
-    def __init__(self):
-        self.is_fitted: bool = False

     @abstractmethod
     def fit(self, X: torch.Tensor, **fit_kwargs: Any) -> "TorchTransformerMixin":

nystrom_ncut/visualize_utils.py CHANGED
@@ -1,5 +1,4 @@
-import logging
-from typing import Any, Callable, Dict, Literal, Union
+from typing import Any, Callable, Dict, Union

 import numpy as np
 import torch
@@ -13,7 +12,8 @@ from .common import (
     quantile_normalize,
 )
 from .distance_utils import (
-    DistanceOptions,
+    AffinityOptions,
+    AFFINITY_TO_DISTANCE,
     to_euclidean,
     affinity_from_features,
 )
@@ -27,7 +27,7 @@ def extrapolate_knn(
     anchor_features: torch.Tensor,          # [n x d]
     anchor_output: torch.Tensor,            # [n x d']
     extrapolation_features: torch.Tensor,   # [m x d]
-    distance: DistanceOptions,
+    affinity_type: AffinityOptions,
     knn: int = 10,                          # k
     affinity_focal_gamma: float = 1.0,
     chunk_size: int = 8192,
@@ -41,7 +41,7 @@ def extrapolate_knn(
         anchor_output (torch.Tensor): output from subgraph, shape (num_sample, D)
         extrapolation_features (torch.Tensor): features from existing nodes, shape (new_num_samples, n_features)
         knn (int): number of KNN to propagate eigenvectors
-        distance (str): distance metric, 'cosine' (default) or 'euclidean', 'rbf'
+        affinity_type (str): distance metric, 'cosine' (default) or 'euclidean', 'rbf'
         chunk_size (int): chunk size for matrix multiplication
         device (str): device to use for computation, if None, will not change device
     Returns:
@@ -66,7 +66,7 @@ def extrapolate_knn(
     for _v in torch.chunk(extrapolation_features, n_chunks, dim=0):
         _v = _v.to(device)                                                  # [_m x d]

-        _A = affinity_from_features(anchor_features, _v, affinity_focal_gamma, distance).mT        # [_m x n]
+        _A = affinity_from_features(anchor_features, _v, affinity_focal_gamma, affinity_type).mT   # [_m x n]
         if knn is not None:
             _A, indices = _A.topk(k=knn, dim=-1, largest=True)              # [_m x k], [_m x k]
             _anchor_output = anchor_output[indices]                         # [_m x k x d]
@@ -90,7 +90,7 @@ def extrapolate_knn_with_subsampling(
     full_output: torch.Tensor,              # [n x d']
     extrapolation_features: torch.Tensor,   # [m x d]
     sample_config: SampleConfig,
-    distance: DistanceOptions,
+    affinity_type: AffinityOptions,
     knn: int = 10,                          # k
     affinity_focal_gamma: float = 1.0,
     chunk_size: int = 8192,
@@ -122,7 +122,7 @@ def extrapolate_knn_with_subsampling(
     # sample subgraph
     anchor_indices = subsample_features(
         features=full_features,
-        disttype=distance,
+        distance_type=AFFINITY_TO_DISTANCE[affinity_type],
         config=sample_config,
     )

@@ -135,7 +135,7 @@ def extrapolate_knn_with_subsampling(
         anchor_features,
         anchor_output,
         extrapolation_features,
-        distance,
+        affinity_type,
         knn=knn,
         affinity_focal_gamma=affinity_focal_gamma,
         chunk_size=chunk_size,
@@ -148,7 +148,7 @@ def extrapolate_knn_with_subsampling(
 def _rgb_with_dimensionality_reduction(
     features: torch.Tensor,
     num_sample: int,
-    disttype: Literal["cosine", "euclidean"],
+    affinity_type: AffinityOptions,
     rgb_func: Callable[[torch.Tensor, float], torch.Tensor],
     q: float,
     knn: int,
@@ -162,26 +162,26 @@ def _rgb_with_dimensionality_reduction(
     if True:
         _subgraph_indices = subsample_features(
             features=features,
-            disttype=disttype,
+            distance_type=AFFINITY_TO_DISTANCE[affinity_type],
             config=SampleConfig(method="fps"),
         )
         features = extrapolate_knn(
             anchor_features=features[_subgraph_indices],
             anchor_output=features[_subgraph_indices],
             extrapolation_features=features,
-            distance=disttype,
+            affinity_type=affinity_type,
         )

     subgraph_indices = subsample_features(
         features=features,
-        disttype=disttype,
+        distance_type=AFFINITY_TO_DISTANCE[affinity_type],
         config=SampleConfig(method="fps", num_sample=num_sample),
     )

     _inp = features[subgraph_indices].numpy(force=True)
     _subgraph_embed = torch.tensor(reduction(
         n_components=reduction_dim,
-        metric=disttype,
+        metric=AFFINITY_TO_DISTANCE[affinity_type],
         random_state=seed,
         **reduction_kwargs
     ).fit_transform(_inp), device=features.device, dtype=features.dtype)
@@ -190,7 +190,7 @@ def _rgb_with_dimensionality_reduction(
         features[subgraph_indices],
         _subgraph_embed,
         features,
-        disttype,
+        affinity_type,
         knn=knn,
         device=device,
         move_output_to_cpu=True
@@ -201,7 +201,7 @@ def _rgb_with_dimensionality_reduction(
 def rgb_from_tsne_2d(
     features: torch.Tensor,
     num_sample: int = 1000,
-    disttype: Literal["cosine", "euclidean"] = "cosine",
+    affinity_type: AffinityOptions = "cosine",
     perplexity: int = 150,
     q: float = 0.95,
     knn: int = 10,
@@ -220,16 +220,12 @@ def rgb_from_tsne_2d(
            "sklearn import failed, please install `pip install scikit-learn`"
         )
     num_sample = min(num_sample, features.shape[0])
-    if perplexity > num_sample // 2:
-        logging.warning(
-            f"perplexity is larger than num_sample, set perplexity to {num_sample // 2}"
-        )
-        perplexity = num_sample // 2
+    perplexity = min(perplexity, num_sample // 2)

     rgb = _rgb_with_dimensionality_reduction(
         features=features,
         num_sample=num_sample,
-        disttype=disttype,
+        affinity_type=affinity_type,
         rgb_func=rgb_from_2d_colormap,
         q=q,
         knn=knn,
@@ -245,7 +241,7 @@ def rgb_from_tsne_2d(
 def rgb_from_tsne_3d(
     features: torch.Tensor,
     num_sample: int = 1000,
-    disttype: Literal["cosine", "euclidean"] = "cosine",
+    affinity_type: AffinityOptions = "cosine",
     perplexity: int = 150,
     q: float = 0.95,
     knn: int = 10,
@@ -264,16 +260,12 @@ def rgb_from_tsne_3d(
             "sklearn import failed, please install `pip install scikit-learn`"
         )
     num_sample = min(num_sample, features.shape[0])
-    if perplexity > num_sample // 2:
-        logging.warning(
-            f"perplexity is larger than num_sample, set perplexity to {num_sample // 2}"
-        )
-        perplexity = num_sample // 2
+    perplexity = min(perplexity, num_sample // 2)

     rgb = _rgb_with_dimensionality_reduction(
         features=features,
         num_sample=num_sample,
-        disttype=disttype,
+        affinity_type=affinity_type,
         rgb_func=rgb_from_3d_rgb_cube,
         q=q,
         knn=knn,
@@ -289,7 +281,7 @@ def rgb_from_tsne_3d(
 def rgb_from_euclidean_tsne_3d(
     features: torch.Tensor,
     num_sample: int = 1000,
-    disttype: Literal["cosine", "euclidean"] = "cosine",
+    affinity_type: AffinityOptions = "cosine",
     perplexity: int = 150,
     q: float = 0.95,
     knn: int = 10,
@@ -308,19 +300,15 @@ def rgb_from_euclidean_tsne_3d(
             "sklearn import failed, please install `pip install scikit-learn`"
         )
     num_sample = min(num_sample, features.shape[0])
-    if perplexity > num_sample // 2:
-        logging.warning(
-            f"perplexity is larger than num_sample, set perplexity to {num_sample // 2}"
-        )
-        perplexity = num_sample // 2
+    perplexity = min(perplexity, num_sample // 2)

     def rgb_func(X_3d: torch.Tensor, q: float) -> torch.Tensor:
-        return rgb_from_3d_rgb_cube(to_euclidean(X_3d, disttype), q=q)
+        return rgb_from_3d_rgb_cube(to_euclidean(X_3d, AFFINITY_TO_DISTANCE[affinity_type]), q=q)

     rgb = _rgb_with_dimensionality_reduction(
         features=features,
         num_sample=num_sample,
-        disttype="cosine",
+        affinity_type=affinity_type,
         rgb_func=rgb_func,
         q=q,
         knn=knn,
@@ -336,7 +324,7 @@ def rgb_from_euclidean_tsne_3d(
 def rgb_from_umap_2d(
     features: torch.Tensor,
     num_sample: int = 1000,
-    disttype: Literal["cosine", "euclidean"] = "cosine",
+    affinity_type: AffinityOptions = "cosine",
     n_neighbors: int = 150,
     min_dist: float = 0.1,
     q: float = 0.95,
@@ -357,7 +345,7 @@ def rgb_from_umap_2d(
     rgb = _rgb_with_dimensionality_reduction(
         features=features,
         num_sample=num_sample,
-        disttype=disttype,
+        affinity_type=affinity_type,
         rgb_func=rgb_from_2d_colormap,
         q=q,
         knn=knn,
@@ -374,7 +362,7 @@ def rgb_from_umap_sphere(
     features: torch.Tensor,
     num_sample: int = 1000,
-    disttype: Literal["cosine", "euclidean"] = "cosine",
+    affinity_type: AffinityOptions = "cosine",
     n_neighbors: int = 150,
     min_dist: float = 0.1,
     q: float = 0.95,
@@ -402,7 +390,7 @@ def rgb_from_umap_sphere(
     rgb = _rgb_with_dimensionality_reduction(
         features=features,
         num_sample=num_sample,
-        disttype=disttype,
+        affinity_type=affinity_type,
         rgb_func=rgb_func,
         q=q,
         knn=knn,
@@ -420,7 +408,7 @@ def rgb_from_umap_3d(
     features: torch.Tensor,
     num_sample: int = 1000,
-    disttype: Literal["cosine", "euclidean"] = "cosine",
+    affinity_type: AffinityOptions = "cosine",
     n_neighbors: int = 150,
     min_dist: float = 0.1,
     q: float = 0.95,
@@ -441,7 +429,7 @@ def rgb_from_umap_3d(
     rgb = _rgb_with_dimensionality_reduction(
         features=features,
         num_sample=num_sample,
-        disttype=disttype,
+        affinity_type=affinity_type,
         rgb_func=rgb_from_3d_rgb_cube,
         q=q,
         knn=knn,
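Note: across visualize_utils the disttype keyword is renamed to affinity_type, and it is mapped back to a metric with AFFINITY_TO_DISTANCE wherever a distance is actually needed (subsampling, the t-SNE/UMAP metric). A hedged call sketch, not from the diff (requires scikit-learn; assumes the function returns the RGB tensor it computes):

import torch
from nystrom_ncut.visualize_utils import rgb_from_tsne_3d

eigenvectors = torch.randn(4096, 20)  # e.g. an NCut embedding
rgb = rgb_from_tsne_3d(
    eigenvectors,
    num_sample=300,
    affinity_type="cosine",  # was disttype="cosine" in 0.2.2
)  # expected shape: [4096 x 3]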

nystrom_ncut-0.3.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: nystrom_ncut
-Version: 0.2.2
+Version: 0.3.0
 Summary: Normalized Cut and Nyström Approximation
 Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
 Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/

nystrom_ncut-0.3.0.dist-info/RECORD ADDED
@@ -0,0 +1,18 @@
+__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+nystrom_ncut/__init__.py,sha256=tKq9-2QRNFetckHY77qAaKEMjMCYTYcorS2f74aNtvk,540
+nystrom_ncut/common.py,sha256=eie19AHTMk6AGTxNnYq1UcFkHJVimeywAUYryXwaiHk,2428
+nystrom_ncut/distance_utils.py,sha256=pJA8NcIKyS7-YDpRGOkc7mwBQQEsYVemdkHiTjyU4n8,4300
+nystrom_ncut/sampling_utils.py,sha256=6lP8F6gftl4mgkavPsD7Vuk4erj4RtgILPhcj3YqLXk,4840
+nystrom_ncut/visualize_utils.py,sha256=Sfi_kKpvFFzBFoJnbo-pQpH2jhs-A6tH64SV_WGoq58,22740
+nystrom_ncut/nystrom/__init__.py,sha256=1aUXK87g4cXRXqNt6XkZsfyauw1-yv3sv0NmdmkWo-8,42
+nystrom_ncut/nystrom/distance_realization.py,sha256=RTI1_Q8fCUGAPSbXaVuNA-2B-11CEAfy2CwKWPJj6xQ,5830
+nystrom_ncut/nystrom/normalized_cut.py,sha256=jB_QALMY3l5CFfZPsrOFpEaquTrJP17muTrDZXxzUA8,7177
+nystrom_ncut/nystrom/nystrom_utils.py,sha256=hksDO8uuAb9xKoA1ZafGwXDlQN_gZJn_qHscaSoO8JE,14120
+nystrom_ncut/transformer/__init__.py,sha256=jjXjcNp3LrxeF6mqG9VY5k3asrqaY6bXzJz6wTpH78Q,105
+nystrom_ncut/transformer/axis_align.py,sha256=6LTR-syJ-f4pcbnMexFmFNn1QADDhH5gka6979YBRrI,3549
+nystrom_ncut/transformer/transformer_mixin.py,sha256=YAjrDWTL5Hjnk9J2OsoxvtwT2N0u8IdgMSx0rRFmZzE,1653
+nystrom_ncut-0.3.0.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
+nystrom_ncut-0.3.0.dist-info/METADATA,sha256=lhxicufu5Eo9HQsUiS_K-CzocemOeNravAaIXeCtriM,6058
+nystrom_ncut-0.3.0.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
+nystrom_ncut-0.3.0.dist-info/top_level.txt,sha256=gM8IWWHYysIRTCvCTcdS4RShOyl9pxpylgSwPUZR2XM,22
+nystrom_ncut-0.3.0.dist-info/RECORD,,

nystrom_ncut-0.3.0.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (75.8.0)
+Generator: setuptools (76.0.0)
 Root-Is-Purelib: true
 Tag: py3-none-any


nystrom_ncut-0.2.2.dist-info/RECORD DELETED
@@ -1,18 +0,0 @@
-__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-nystrom_ncut/__init__.py,sha256=tKq9-2QRNFetckHY77qAaKEMjMCYTYcorS2f74aNtvk,540
-nystrom_ncut/common.py,sha256=_PGJoImSk_Fb_5Ri-e_IsFoCcSfbGS8CxYUUHVoNM50,2036
-nystrom_ncut/distance_utils.py,sha256=p-pYdpRrJsIhzxM_IxUqja7N8okngx52WGXD9pu_Aec,3129
-nystrom_ncut/sampling_utils.py,sha256=oMmhFcd_N_D15Ht7F0rCGPSgLeitJszAKMD3ICKwHNU,3105
-nystrom_ncut/visualize_utils.py,sha256=d3VXjzJPZPPyUMg_b8hKLQoBaRWvutu6u7l36S2gmIM,23007
-nystrom_ncut/nystrom/__init__.py,sha256=lAoO00i4FG5xqGKDO_OYcSvO4qPK64x_X_hDNBvuLUc,105
-nystrom_ncut/nystrom/distance_realization.py,sha256=InajllGtRVnLVlZoipZNbHFTGHaTs3zxizKe3kI2Los,5815
-nystrom_ncut/nystrom/normalized_cut.py,sha256=2ocwc4U3A6GGFs0cuL0DO1yNvt59SJ3uDtj00U0foPM,5906
-nystrom_ncut/nystrom/nystrom_utils.py,sha256=MMRMq7N_JYM2whSHIhCoPA-SQ28wb9hC8u2CZNmRRN8,12836
-nystrom_ncut/transformer/__init__.py,sha256=jjXjcNp3LrxeF6mqG9VY5k3asrqaY6bXzJz6wTpH78Q,105
-nystrom_ncut/transformer/axis_align.py,sha256=8PYtSTChHDTrh5SYdhl1ALsUUPJHd9ojQRM1e6KTbHc,3579
-nystrom_ncut/transformer/transformer_mixin.py,sha256=kVjODpIHB6noC7yGY-QPf67Ep58o53wrEMKSNjhhChI,1714
-nystrom_ncut-0.2.2.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
-nystrom_ncut-0.2.2.dist-info/METADATA,sha256=RGX2HMT2uF9bUB_4qecpX87bY6N1gg6-QKXvhJVnzIo,6058
-nystrom_ncut-0.2.2.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
-nystrom_ncut-0.2.2.dist-info/top_level.txt,sha256=gM8IWWHYysIRTCvCTcdS4RShOyl9pxpylgSwPUZR2XM,22
-nystrom_ncut-0.2.2.dist-info/RECORD,,