nystrom-ncut 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -25,15 +25,15 @@ EigSolverOptions = Literal["svd_lowrank", "lobpcg", "svd", "eigh"]
25
25
 
26
26
  class OnlineKernel:
27
27
  @abstractmethod
28
- def fit(self, features: torch.Tensor) -> "OnlineKernel": # [n x d]
28
+ def fit(self, features: torch.Tensor) -> "OnlineKernel": # [... x n x d]
29
29
  """"""
30
30
 
31
31
  @abstractmethod
32
- def update(self, features: torch.Tensor) -> torch.Tensor: # [m x d] -> [m x n]
32
+ def update(self, features: torch.Tensor) -> torch.Tensor: # [... x m x d] -> [... x m x n]
33
33
  """"""
34
34
 
35
35
  @abstractmethod
36
- def transform(self, features: torch.Tensor = None) -> torch.Tensor: # [m x d] -> [m x n]
36
+ def transform(self, features: torch.Tensor = None) -> torch.Tensor: # [... x m x d] -> [... x m x n]
37
37
  """"""
38
38
 
39
39
 
@@ -54,20 +54,21 @@ class OnlineNystrom(TorchTransformerMixin):
54
54
  self.n_components: int = n_components
55
55
  self.kernel: OnlineKernel = kernel
56
56
  self.eig_solver: EigSolverOptions = eig_solver
57
+ self.shape: torch.Size = None # ...
57
58
 
58
59
  self.chunk_size = chunk_size
59
60
 
60
61
  # Anchor matrices
61
- self.anchor_features: torch.Tensor = None # [n x d]
62
- self.A: torch.Tensor = None # [n x n]
63
- self.Ahinv: torch.Tensor = None # [n x n]
64
- self.Ahinv_UL: torch.Tensor = None # [n x indirect_pca_dim]
65
- self.Ahinv_VT: torch.Tensor = None # [indirect_pca_dim x n]
62
+ self.anchor_features: torch.Tensor = None # [... x n x d]
63
+ self.A: torch.Tensor = None # [... x n x n]
64
+ self.Ahinv: torch.Tensor = None # [... x n x n]
65
+ self.Ahinv_UL: torch.Tensor = None # [... x n x indirect_pca_dim]
66
+ self.Ahinv_VT: torch.Tensor = None # [... x indirect_pca_dim x n]
66
67
 
67
68
  # Updated matrices
68
- self.S: torch.Tensor = None # [n x n]
69
- self.transform_matrix: torch.Tensor = None # [n x n_components]
70
- self.eigenvalues_: torch.Tensor = None # [n]
69
+ self.S: torch.Tensor = None # [... x n x n]
70
+ self.transform_matrix: torch.Tensor = None # [... x n x n_components]
71
+ self.eigenvalues_: torch.Tensor = None # [... x n]
71
72
 
72
73
  def _update_to_kernel(self, d: int) -> Tuple[torch.Tensor, torch.Tensor]:
73
74
  self.A = self.S = self.kernel.transform()
@@ -75,10 +76,10 @@ class OnlineNystrom(TorchTransformerMixin):
75
76
  self.A,
76
77
  num_eig=d + 1, # d * (d + 3) // 2 + 1,
77
78
  eig_solver=self.eig_solver,
78
- ) # [n x (? + 1)], [? + 1]
79
- self.Ahinv_UL = U * (L ** -0.5) # [n x (? + 1)]
80
- self.Ahinv_VT = U.mT # [(? + 1) x n]
81
- self.Ahinv = self.Ahinv_UL @ self.Ahinv_VT # [n x n]
79
+ ) # [... x n x (? + 1)], [... x (? + 1)]
80
+ self.Ahinv_UL = U * (L[..., None, :] ** -0.5) # [... x n x (? + 1)]
81
+ self.Ahinv_VT = U.mT # [... x (? + 1) x n]
82
+ self.Ahinv = self.Ahinv_UL @ self.Ahinv_VT # [... x n x n]
82
83
  return U, L
83
84
 
84
85
  def fit(self, features: torch.Tensor) -> "OnlineNystrom":
@@ -89,65 +90,63 @@ class OnlineNystrom(TorchTransformerMixin):
89
90
  self.anchor_features = features
90
91
 
91
92
  self.kernel.fit(self.anchor_features)
92
- U, L = self._update_to_kernel(features.shape[-1]) # [n x (d + 1)], [d + 1]
93
+ U, L = self._update_to_kernel(features.shape[-1]) # [... x n x (d + 1)], [... x (d + 1)]
93
94
 
94
- self.transform_matrix = (U / L)[:, :self.n_components] # [n x n_components]
95
- self.eigenvalues_ = L[:self.n_components] # [n_components]
96
- self.is_fitted = True
97
- return U[:, :self.n_components] # [n x n_components]
95
+ self.transform_matrix = (U / L[..., None, :])[..., :, :self.n_components] # [... x n x n_components]
96
+ self.eigenvalues_ = L[..., :self.n_components] # [... x n_components]
97
+ return U[..., :, :self.n_components] # [... x n x n_components]
98
98
 
99
99
  def update(self, features: torch.Tensor) -> torch.Tensor:
100
100
  d = features.shape[-1]
101
- n_chunks = ceildiv(len(features), self.chunk_size)
101
+ n_chunks = ceildiv(features.shape[-2], self.chunk_size)
102
102
  if n_chunks > 1:
103
103
  """ Chunked version """
104
- chunks = torch.chunk(features, n_chunks, dim=0)
104
+ chunks = torch.chunk(features, n_chunks, dim=-2)
105
105
  for chunk in chunks:
106
106
  self.kernel.update(chunk)
107
107
  self._update_to_kernel(d)
108
108
 
109
- compressed_BBT = 0.0 # [(? + 1) x (? + 1))]
109
+ compressed_BBT = 0.0 # [... x (? + 1) x (? + 1))]
110
110
  for chunk in chunks:
111
- _B = self.kernel.transform(chunk).mT # [n x _m]
112
- _compressed_B = self.Ahinv_VT @ _B # [(? + 1) x _m]
113
- compressed_BBT = compressed_BBT + _compressed_B @ _compressed_B.mT # [(? + 1) x (? + 1)]
114
- self.S = self.S + self.Ahinv_UL @ compressed_BBT @ self.Ahinv_UL.mT # [n x n]
115
- US, self.eigenvalues_ = solve_eig(self.S, self.n_components, self.eig_solver) # [n x n_components], [n_components]
116
- self.transform_matrix = self.Ahinv @ US * (self.eigenvalues_ ** -0.5) # [n x n_components]
111
+ _B = self.kernel.transform(chunk).mT # [... x n x _m]
112
+ _compressed_B = self.Ahinv_VT @ _B # [... x (? + 1) x _m]
113
+ _compressed_B = torch.nan_to_num(_compressed_B, nan=0.0)
114
+ compressed_BBT = compressed_BBT + _compressed_B @ _compressed_B.mT # [... x (? + 1) x (? + 1)]
115
+ self.S = self.S + self.Ahinv_UL @ compressed_BBT @ self.Ahinv_UL.mT # [... x n x n]
116
+ US, self.eigenvalues_ = solve_eig(self.S, self.n_components, self.eig_solver) # [... x n x n_components], [... x n_components]
117
+ self.transform_matrix = self.Ahinv @ US * (self.eigenvalues_[..., None, :] ** -0.5) # [... x n x n_components]
117
118
 
118
119
  VS = []
119
120
  for chunk in chunks:
120
- VS.append(self.kernel.transform(chunk) @ self.transform_matrix) # [_m x n_components]
121
- VS = torch.cat(VS, dim=0)
122
- return VS # [m x n_components]
121
+ VS.append(self.kernel.transform(chunk) @ self.transform_matrix) # [... x _m x n_components]
122
+ VS = torch.cat(VS, dim=-2)
123
+ return VS # [... x m x n_components]
123
124
  else:
124
125
  """ Unchunked version """
125
- B = self.kernel.update(features).mT # [n x m]
126
+ B = self.kernel.update(features).mT # [... x n x m]
126
127
  self._update_to_kernel(d)
127
- compressed_B = self.Ahinv_VT @ B # [indirect_pca_dim x m]
128
+ compressed_B = self.Ahinv_VT @ B # [... x (? + 1) x m]
129
+ compressed_B = torch.nan_to_num(compressed_B, nan=0.0)
128
130
 
129
- self.S = self.S + self.Ahinv_UL @ (compressed_B @ compressed_B.mT) @ self.Ahinv_UL.mT # [n x n]
130
- US, self.eigenvalues_ = solve_eig(self.S, self.n_components, self.eig_solver) # [n x n_components], [n_components]
131
- self.transform_matrix = self.Ahinv @ US * (self.eigenvalues_ ** -0.5) # [n x n_components]
131
+ self.S = self.S + self.Ahinv_UL @ (compressed_B @ compressed_B.mT) @ self.Ahinv_UL.mT # [... x n x n]
132
+ US, self.eigenvalues_ = solve_eig(self.S, self.n_components, self.eig_solver) # [... x n x n_components], [... x n_components]
133
+ self.transform_matrix = self.Ahinv @ US * (self.eigenvalues_[..., None, :] ** -0.5) # [... x n x n_components]
132
134
 
133
- return B.mT @ self.transform_matrix # [m x n_components]
135
+ return B.mT @ self.transform_matrix # [... x m x n_components]
134
136
 
135
- def transform(self, features: torch.Tensor = None) -> torch.Tensor:
136
- if features is None:
137
- VS = self.A @ self.transform_matrix # [n x n_components]
137
+ def transform(self, features: torch.Tensor) -> torch.Tensor:
138
+ n_chunks = ceildiv(features.shape[-2], self.chunk_size)
139
+ if n_chunks > 1:
140
+ """ Chunked version """
141
+ chunks = torch.chunk(features, n_chunks, dim=-2)
142
+ VS = []
143
+ for chunk in chunks:
144
+ VS.append(self.kernel.transform(chunk) @ self.transform_matrix) # [... x _m x n_components]
145
+ VS = torch.cat(VS, dim=-2)
138
146
  else:
139
- n_chunks = ceildiv(len(features), self.chunk_size)
140
- if n_chunks > 1:
141
- """ Chunked version """
142
- chunks = torch.chunk(features, n_chunks, dim=0)
143
- VS = []
144
- for chunk in chunks:
145
- VS.append(self.kernel.transform(chunk) @ self.transform_matrix) # [_m x n_components]
146
- VS = torch.cat(VS, dim=0)
147
- else:
148
- """ Unchunked version """
149
- VS = self.kernel.transform(features) @ self.transform_matrix # [m x n_components]
150
- return VS # [m x n_components]
147
+ """ Unchunked version """
148
+ VS = self.kernel.transform(features) @ self.transform_matrix # [... x m x n_components]
149
+ return VS # [... x m x n_components]
151
150
 
152
151
 
153
152
  class OnlineNystromSubsampleFit(OnlineNystrom):
@@ -155,7 +154,7 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
155
154
  self,
156
155
  n_components: int,
157
156
  kernel: OnlineKernel,
158
- distance: DistanceOptions,
157
+ distance_type: DistanceOptions,
159
158
  sample_config: SampleConfig,
160
159
  eig_solver: EigSolverOptions = "svd_lowrank",
161
160
  chunk_size: int = 8192,
@@ -167,7 +166,7 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
167
166
  eig_solver=eig_solver,
168
167
  chunk_size=chunk_size,
169
168
  )
170
- self.distance: DistanceOptions = distance
169
+ self.distance_type: DistanceOptions = distance_type
171
170
  self.sample_config: SampleConfig = sample_config
172
171
  self.sample_config._ncut_obj = copy.deepcopy(self)
173
172
  self.anchor_indices: torch.Tensor = None
@@ -177,7 +176,7 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
177
176
  features: torch.Tensor,
178
177
  precomputed_sampled_indices: torch.Tensor,
179
178
  ) -> Tuple[torch.Tensor, torch.Tensor]:
180
- _n = features.shape[0]
179
+ _n = features.shape[-2]
181
180
  if self.sample_config.num_sample >= _n:
182
181
  logging.info(
183
182
  f"NCUT nystrom num_sample is larger than number of input samples, nyström approximation is not needed, setting num_sample={_n}"
@@ -189,16 +188,17 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
189
188
  else:
190
189
  self.anchor_indices = subsample_features(
191
190
  features=features,
192
- disttype=self.distance,
191
+ distance_type=self.distance_type,
193
192
  config=self.sample_config,
194
193
  )
195
- sampled_features = features[self.anchor_indices]
194
+ sampled_features = torch.gather(features, -2, self.anchor_indices[..., None].expand([-1] * self.anchor_indices.ndim + [features.shape[-1]]))
196
195
  OnlineNystrom.fit(self, sampled_features)
197
196
 
198
- _n_not_sampled = _n - len(sampled_features)
197
+ _n_not_sampled = _n - self.anchor_indices.shape[-1]
199
198
  if _n_not_sampled > 0:
200
- unsampled_indices = torch.full((_n,), True, device=features.device).scatter_(0, self.anchor_indices, False)
201
- unsampled_features = features[unsampled_indices]
199
+ unsampled_mask = torch.full(features.shape[:-1], True, device=features.device).scatter_(-1, self.anchor_indices, False)
200
+ unsampled_indices = torch.where(unsampled_mask)[-1].view((*features.shape[:-2], -1))
201
+ unsampled_features = torch.gather(features, -2, unsampled_indices[..., None].expand([-1] * unsampled_indices.ndim + [features.shape[-1]]))
202
202
  V_unsampled = OnlineNystrom.update(self, unsampled_features)
203
203
  else:
204
204
  unsampled_indices = V_unsampled = None
@@ -236,12 +236,12 @@ class OnlineNystromSubsampleFit(OnlineNystrom):
236
236
  (torch.Tensor): eigen_values, sorted in descending order, shape (num_eig,)
237
237
  """
238
238
  unsampled_indices, V_unsampled = OnlineNystromSubsampleFit._fit_helper(self, features, precomputed_sampled_indices)
239
- V_sampled = OnlineNystrom.transform(self)
239
+ V_sampled = OnlineNystrom.transform(self, self.anchor_features)
240
240
 
241
241
  if unsampled_indices is not None:
242
- V = torch.zeros((len(unsampled_indices), self.n_components), device=features.device)
243
- V[~unsampled_indices] = V_sampled
244
- V[unsampled_indices] = V_unsampled
242
+ V = torch.zeros((*features.shape[:-1], self.n_components), device=features.device)
243
+ for (indices, _V) in [(self.anchor_indices, V_sampled), (unsampled_indices, V_unsampled)]:
244
+ V.scatter_(-2, indices[..., None].expand([-1] * indices.ndim + [self.n_components]), _V)
245
245
  else:
246
246
  V = V_sampled
247
247
  return V
@@ -264,12 +264,16 @@ def solve_eig(
264
264
  (torch.Tensor): eigenvectors corresponding to the eigenvalues, shape (n_samples, num_eig)
265
265
  (torch.Tensor): eigenvalues of the eigenvectors, sorted in descending order
266
266
  """
267
- A = A + eig_value_buffer * torch.eye(A.shape[0], device=A.device)
267
+ shape: torch.Size = A.shape[:-2]
268
+ A = A.view((-1, *A.shape[-2:]))
269
+ bsz: int = A.shape[0]
270
+
271
+ A = A + eig_value_buffer * torch.eye(A.shape[-1], device=A.device)
268
272
 
269
273
  # compute eigenvectors
270
274
  if eig_solver == "svd_lowrank": # default
271
275
  # only top q eigenvectors, fastest
272
- eigen_vector, eigen_value, _ = torch.svd_lowrank(A, q=num_eig)
276
+ eigen_vector, eigen_value, _ = torch.svd_lowrank(A, q=num_eig) # complex: [(...) x N x D], [(...) x D]
273
277
  elif eig_solver == "lobpcg":
274
278
  # only top k eigenvectors, fast
275
279
  eigen_value, eigen_vector = torch.lobpcg(A, k=num_eig)
@@ -286,11 +290,15 @@ def solve_eig(
286
290
  eigen_value = eigen_value - eig_value_buffer
287
291
 
288
292
  # sort eigenvectors by eigenvalues, take top (descending order)
289
- indices = torch.topk(eigen_value.abs(), k=num_eig, dim=0).indices
290
- eigen_value, eigen_vector = eigen_value[indices], eigen_vector[:, indices]
293
+ indices = torch.topk(eigen_value.abs(), k=num_eig, dim=-1).indices # int: [(...) x S]
294
+ eigen_value = eigen_value[torch.arange(bsz)[:, None], indices] # complex: [(...) x S]
295
+ eigen_vector = eigen_vector[torch.arange(bsz)[:, None], :, indices].mT # complex: [(...) x N x S]
291
296
 
292
297
  # correct the random rotation (flipping sign) of eigenvectors
293
- sign = torch.sum(eigen_vector.real, dim=0).sign()
298
+ sign = torch.sign(torch.sum(eigen_vector.real, dim=-2, keepdim=True)) # float: [(...) x 1 x S]
294
299
  sign[sign == 0] = 1.0
295
300
  eigen_vector = eigen_vector * sign
301
+
302
+ eigen_value = eigen_value.view((*shape, *eigen_value.shape[-1:])) # complex: [... x S]
303
+ eigen_vector = eigen_vector.view((*shape, *eigen_vector.shape[-2:])) # complex: [... x N x S]
296
304
  return eigen_vector, eigen_value
@@ -1,17 +1,22 @@
1
- import logging
2
1
  from dataclasses import dataclass
3
2
  from typing import Literal
4
3
 
5
4
  import torch
6
5
  from pytorch3d.ops import sample_farthest_points
7
6
 
7
+ from .common import (
8
+ default_device,
9
+ )
8
10
  from .distance_utils import (
9
11
  DistanceOptions,
10
12
  to_euclidean,
11
13
  )
14
+ from .transformer import (
15
+ TorchTransformerMixin,
16
+ )
12
17
 
13
18
 
14
- SampleOptions = Literal["random", "fps", "fps_recursive"]
19
+ SampleOptions = Literal["full", "random", "fps", "fps_recursive"]
15
20
 
16
21
 
17
22
  @dataclass
@@ -20,69 +25,77 @@ class SampleConfig:
20
25
  num_sample: int = 10000
21
26
  fps_dim: int = 12
22
27
  n_iter: int = None
23
- _ncut_obj: object = None
28
+ _ncut_obj: TorchTransformerMixin = None
24
29
 
25
30
 
26
31
  @torch.no_grad()
27
32
  def subsample_features(
28
33
  features: torch.Tensor,
29
- disttype: DistanceOptions,
34
+ distance_type: DistanceOptions,
30
35
  config: SampleConfig,
31
- max_draw: int = 1000000,
32
36
  ):
33
- features = features.detach()
34
- if config.num_sample >= features.shape[0]:
35
- # if too many samples, use all samples and bypass Nystrom-like approximation
36
- logging.info(
37
- "num_sample is larger than total, bypass Nystrom-like approximation"
38
- )
39
- sampled_indices = torch.arange(features.shape[0])
40
- else:
41
- # sample subgraph
42
- if config.method == "fps": # default
43
- features = to_euclidean(features, disttype)
44
- if config.num_sample > max_draw:
45
- logging.warning(
46
- f"num_sample is larger than max_draw, apply farthest point sampling on random sampled {max_draw} samples"
47
- )
48
- draw_indices = torch.randperm(features.shape[0])[:max_draw]
49
- sampled_indices = fpsample(features[draw_indices], config)
50
- sampled_indices = draw_indices[sampled_indices]
51
- else:
52
- sampled_indices = fpsample(features, config)
53
-
54
- elif config.method == "random": # not recommended
55
- sampled_indices = torch.randperm(features.shape[0])[:config.num_sample]
56
-
57
- elif config.method == "fps_recursive":
58
- features = to_euclidean(features, disttype)
59
- sampled_indices = subsample_features(
60
- features=features,
61
- disttype=disttype,
62
- config=SampleConfig(method="fps", num_sample=config.num_sample, fps_dim=config.fps_dim)
63
- )
64
- nc = config._ncut_obj
65
- for _ in range(config.n_iter):
66
- fps_features, eigenvalues = nc.fit_transform(features, precomputed_sampled_indices=sampled_indices)
67
-
68
- fps_features = to_euclidean(fps_features[:, :config.fps_dim], "cosine")
69
- sampled_indices = torch.sort(fpsample(fps_features, config)).values
37
+ features = features.detach() # float: [... x n x d]
38
+ with default_device(features.device):
39
+ if config.method == "full" or config.num_sample >= features.shape[0]:
40
+ sampled_indices = torch.arange(features.shape[-2]).expand(features.shape[:-1]) # int: [... x n]
70
41
  else:
71
- raise ValueError("sample_method should be 'farthest' or 'random'")
72
- sampled_indices = torch.sort(sampled_indices).values
73
- return sampled_indices.to(features.device)
42
+ # sample
43
+ match config.method:
44
+ case "fps": # default
45
+ sampled_indices = fpsample(to_euclidean(features, distance_type), config)
46
+
47
+ case "random": # not recommended
48
+ mask = torch.all(torch.isfinite(features), dim=-1) # bool: [... x n]
49
+ weights = mask.to(torch.float) + torch.rand(mask.shape) # float: [... x n]
50
+ sampled_indices = torch.topk(weights, k=config.num_sample, dim=-1).indices # int: [... x num_sample]
51
+
52
+ case "fps_recursive":
53
+ features = to_euclidean(features, distance_type) # float: [... x n x d]
54
+ sampled_indices = subsample_features(
55
+ features=features,
56
+ distance_type=distance_type,
57
+ config=SampleConfig(method="fps", num_sample=config.num_sample, fps_dim=config.fps_dim)
58
+ ) # int: [... x num_sample]
59
+ nc = config._ncut_obj
60
+ for _ in range(config.n_iter):
61
+ fps_features, eigenvalues = nc.fit_transform(features, precomputed_sampled_indices=sampled_indices)
62
+
63
+ fps_features = to_euclidean(fps_features[:, :config.fps_dim], "cosine")
64
+ sampled_indices = torch.sort(fpsample(fps_features, config), dim=-1).values
65
+
66
+ case _:
67
+ raise ValueError("sample_method should be 'farthest' or 'random'")
68
+ sampled_indices = torch.sort(sampled_indices, dim=-1).values
69
+ return sampled_indices
74
70
 
75
71
 
76
72
  def fpsample(
77
73
  features: torch.Tensor,
78
74
  config: SampleConfig,
79
75
  ):
80
- # PCA to reduce the dimension
81
- if features.shape[1] > config.fps_dim:
82
- U, S, V = torch.pca_lowrank(features, q=config.fps_dim)
83
- features = U * S
76
+ shape = features.shape[:-2] # ...
77
+ features = features.view((-1, *features.shape[-2:])) # [(...) x n x d]
78
+ bsz = features.shape[0]
79
+
80
+ mask = torch.all(torch.isfinite(features), dim=-1) # bool: [(...) x n]
81
+ count = torch.sum(mask, dim=-1) # int: [(...)]
82
+ order = torch.topk(mask.to(torch.int), k=torch.max(count).item(), dim=-1).indices # int: [(...) x max_count]
83
+
84
+ features = torch.nan_to_num(features[torch.arange(bsz)[:, None], order], nan=0.0) # float: [(...) x max_count x d]
85
+ if features.shape[-1] > config.fps_dim:
86
+ U, S, V = torch.pca_lowrank(features, q=config.fps_dim) # float: [(...) x max_count x fps_dim], [(...) x fps_dim], [(...) x fps_dim x fps_dim]
87
+ features = U * S[..., None, :] # float: [(...) x max_count x fps_dim]
84
88
 
85
89
  try:
86
- return sample_farthest_points(features[None], K=config.num_sample)[1][0]
90
+ sample_indices = sample_farthest_points(
91
+ features, lengths=count, K=config.num_sample
92
+ )[1] # int: [(...) x num_sample]
87
93
  except RuntimeError:
88
- return sample_farthest_points(features[None].cpu(), K=config.num_sample)[1][0].to(features.device)
94
+ original_device = features.device
95
+ alternative_device = "cuda" if original_device == "cpu" else "cpu"
96
+ sample_indices = sample_farthest_points(
97
+ features.to(alternative_device), lengths=count.to(alternative_device), K=config.num_sample
98
+ )[1].to(original_device) # int: [(...) x num_sample]
99
+ sample_indices = torch.gather(order, 1, sample_indices) # int: [(...) x num_sample]
100
+
101
+ return sample_indices.view((*shape, *sample_indices.shape[-1:])) # int: [... x num_sample]
@@ -3,6 +3,9 @@ from typing import Literal
3
3
  import torch
4
4
  import torch.nn.functional as Fn
5
5
 
6
+ from ..common import (
7
+ default_device,
8
+ )
6
9
  from .transformer_mixin import (
7
10
  TorchTransformerMixin,
8
11
  )
@@ -27,51 +30,59 @@ class AxisAlign(TorchTransformerMixin):
27
30
 
28
31
  def fit(self, X: torch.Tensor) -> "AxisAlign":
29
32
  # Normalize eigenvectors
30
- n, d = X.shape
31
- normalized_X = Fn.normalize(X, p=2, dim=-1)
32
-
33
- # Initialize R matrix with the first column from a random row of EigenVectors
34
- self.R = torch.empty((d, d), device=X.device)
35
- self.R[0] = normalized_X[torch.randint(0, n, (), device=X.device)]
36
-
37
- # Loop to populate R with k orthogonal directions
38
- c = torch.zeros((n,), device=X.device)
39
- for i in range(1, d):
40
- c += torch.abs(normalized_X @ self.R[i - 1])
41
- self.R[i] = normalized_X[torch.argmin(c, dim=0)]
42
-
43
- # Iterative optimization loop
44
- idx, prev_objective = None, torch.inf
45
- for _ in range(self.max_iter):
46
- # Discretize the projected eigenvectors
47
- idx = torch.argmax(normalized_X @ self.R.mT, dim=-1)
48
- M = torch.zeros((d, d), device=X.device).index_add_(0, idx, normalized_X)
49
-
50
- # Check for convergence
51
- objective = torch.norm(M)
52
- if torch.abs(objective - prev_objective) < torch.finfo(torch.float32).eps:
53
- break
54
- prev_objective = objective
55
-
56
- # SVD decomposition to compute the next R
57
- U, S, Vh = torch.linalg.svd(M, full_matrices=False)
58
- self.R = U @ Vh
59
-
60
- # Permute the rotation matrix so the dimensions are sorted in descending cluster significance
61
- if self.sort_method == "count":
62
- sort_metric = torch.bincount(idx, minlength=d)
63
- elif self.sort_method == "norm":
64
- rotated_X = X @ self.R.mT
65
- sort_metric = torch.linalg.norm(rotated_X, dim=0)
66
- elif self.sort_method == "marginal_norm":
67
- rotated_X = X @ self.R.mT
68
- sort_metric = torch.zeros((d,), device=X.device).index_add_(0, idx, rotated_X[range(n), idx] ** 2)
69
- else:
70
- raise ValueError(f"Invalid sort method {self.sort_method}.")
71
-
72
- self.R = self.R[torch.argsort(sort_metric, dim=0, descending=True)]
73
- self.is_fitted = True
74
- return self
33
+ with default_device(X.device):
34
+ d = X.shape[-1]
35
+ normalized_X = Fn.normalize(X, p=2, dim=-1) # float: [... x n x d]
36
+
37
+ # Initialize R matrix with the first column from a random row of EigenVectors
38
+ def get_idx(idx: torch.Tensor) -> torch.Tensor:
39
+ return torch.gather(normalized_X, -2, idx[..., None, None].expand([-1] * (X.ndim - 2) + [1, d]))[..., 0, :]
40
+
41
+ self.R = torch.empty((*X.shape[:-2], d, d)) # float: [... x d x d]
42
+ mask = torch.all(torch.isfinite(normalized_X), dim=-1) # bool: [... x n]
43
+ start_idx = torch.argmax(mask.to(torch.float) + torch.rand(mask.shape), dim=-1) # int: [...]
44
+ self.R[..., 0, :] = get_idx(start_idx)
45
+
46
+ # Loop to populate R with k orthogonal directions
47
+ c = torch.zeros(X.shape[:-1]) # float: [... x n]
48
+ for i in range(1, d):
49
+ c += torch.abs(normalized_X @ self.R[..., i - 1, :, None])[..., 0]
50
+ self.R[..., i, :] = get_idx(torch.argmin(c.nan_to_num(nan=torch.inf), dim=-1))
51
+
52
+ # Iterative optimization loop
53
+ normalized_X = torch.nan_to_num(normalized_X, nan=0.0)
54
+ idx, prev_objective = None, torch.inf
55
+ for _ in range(self.max_iter):
56
+ # Discretize the projected eigenvectors
57
+ idx = torch.argmax(normalized_X @ self.R.mT, dim=-1) # int: [... x n]
58
+ M = torch.sum((idx[..., None] == torch.arange(d))[..., None] * normalized_X[..., :, None, :], dim=-3) # float: [... x d x d]
59
+
60
+ # Check for convergence
61
+ objective = torch.norm(M)
62
+ if torch.abs(objective - prev_objective) < torch.finfo(torch.float32).eps:
63
+ break
64
+ prev_objective = objective
65
+
66
+ # SVD decomposition to compute the next R
67
+ U, S, Vh = torch.linalg.svd(M, full_matrices=False)
68
+ self.R = U @ Vh
69
+
70
+ # Permute the rotation matrix so the dimensions are sorted in descending cluster significance
71
+ match self.sort_method:
72
+ case "count":
73
+ sort_metric = torch.sum((idx[..., None] == torch.arange(d)), dim=-2)
74
+ case "norm":
75
+ rotated_X = torch.nan_to_num(X @ self.R.mT, nan=0.0)
76
+ sort_metric = torch.linalg.norm(rotated_X, dim=-2)
77
+ case "marginal_norm":
78
+ rotated_X = torch.nan_to_num(X @ self.R.mT, nan=0.0)
79
+ sort_metric = torch.sum((idx[..., None] == torch.arange(d)) * (torch.gather(rotated_X, -1, idx[..., None]) ** 2), dim=-2)
80
+ case _:
81
+ raise ValueError(f"Invalid sort method {self.sort_method}.")
82
+
83
+ order = torch.argsort(sort_metric, dim=-1, descending=True)
84
+ self.R = torch.gather(self.R, -2, order[..., None].expand([-1] * order.ndim + [d]))
85
+ return self
75
86
 
76
87
  def transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
77
88
  """
@@ -83,9 +94,9 @@ class AxisAlign(TorchTransformerMixin):
83
94
  torch.Tensor: Discretized eigenvectors, shape (n, k), each row is a one-hot vector.
84
95
  """
85
96
  if normalize:
86
- X = Fn.normalize(X, p=2, dim=1)
97
+ X = Fn.normalize(X, p=2, dim=-1)
87
98
  rotated_X = X @ self.R.mT
88
- return torch.argmax(rotated_X, dim=1) if hard else rotated_X
99
+ return torch.argmax(rotated_X, dim=-1) if hard else rotated_X
89
100
 
90
101
  def fit_transform(self, X: torch.Tensor, normalize: bool = True, hard: bool = False) -> torch.Tensor:
91
102
  return self.fit(X).transform(X, normalize=normalize, hard=hard)
@@ -36,8 +36,6 @@ class TorchTransformerMixin:
36
36
  >>> transformer.fit_transform(X)
37
37
  array([1, 1, 1])
38
38
  """
39
- def __init__(self):
40
- self.is_fitted: bool = False
41
39
 
42
40
  @abstractmethod
43
41
  def fit(self, X: torch.Tensor, **fit_kwargs: Any) -> "TorchTransformerMixin":