nystrom-ncut 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -173,9 +173,10 @@ class NCUT(OnlineNystrom):
173
173
  else:
174
174
  sampled_indices = run_subgraph_sampling(
175
175
  features,
176
- num_sample=self.num_sample,
176
+ self.num_sample,
177
177
  sample_method=self.sample_method,
178
178
  )
179
+ sampled_indices = torch.sort(sampled_indices).values
179
180
  sampled_features = features[sampled_indices]
180
181
  OnlineNystrom.fit(self, sampled_features)
181
182
 
@@ -11,7 +11,7 @@ from .common import ceildiv, lazy_normalize
11
11
  @torch.no_grad()
12
12
  def run_subgraph_sampling(
13
13
  features: torch.Tensor,
14
- num_sample: int = 300,
14
+ num_sample: int,
15
15
  max_draw: int = 1000000,
16
16
  sample_method: Literal["farthest", "random"] = "farthest",
17
17
  ):
@@ -96,7 +96,6 @@ def distance_from_features(
96
96
  D = D / (2 * features.var(dim=0).sum())
97
97
  else:
98
98
  raise ValueError("distance should be 'cosine' or 'euclidean', 'rbf'")
99
-
100
99
  return D
101
100
 
102
101
 
@@ -184,13 +183,37 @@ def propagate_knn(
184
183
  V_list = []
185
184
  for _v in torch.chunk(inp_features, n_chunks, dim=0):
186
185
  _v = _v.to(device)
187
- _A = affinity_from_features(subgraph_features, _v, affinity_focal_gamma, distance).mT
188
186
 
189
- if knn is not None:
190
- mask = torch.full_like(_A, True, dtype=torch.bool)
191
- mask[torch.arange(len(_v))[:, None], _A.topk(knn, dim=-1, largest=True).indices] = False
192
- _A[mask] = 0.0
193
- _A = F.normalize(_A, p=1, dim=-1)
187
+ # _A = affinity_from_features(subgraph_features, _v, affinity_focal_gamma, distance).mT
188
+ # if knn is not None:
189
+ # mask = torch.full_like(_A, True, dtype=torch.bool)
190
+ # mask[torch.arange(len(_v))[:, None], _A.topk(knn, dim=-1, largest=True).indices] = False
191
+ # _A[mask] = 0.0
192
+ # _A = F.normalize(_A, p=1, dim=-1)
193
+
194
+ if distance == 'cosine':
195
+ _A = _v @ subgraph_features.T
196
+ elif distance == 'euclidean':
197
+ _A = - torch.cdist(_v, subgraph_features, p=2)
198
+ elif distance == 'rbf':
199
+ _A = - torch.cdist(_v, subgraph_features, p=2) ** 2
200
+ else:
201
+ raise ValueError("distance should be 'cosine' or 'euclidean', 'rbf'")
202
+
203
+ # keep topk KNN for each row
204
+ topk_sim, topk_idx = _A.topk(knn, dim=-1, largest=True)
205
+ row_id = torch.arange(topk_idx.shape[0], device=_A.device)[:, None].expand(
206
+ -1, topk_idx.shape[1]
207
+ )
208
+ _A = torch.sparse_coo_tensor(
209
+ torch.stack([row_id, topk_idx], dim=-1).reshape(-1, 2).T,
210
+ topk_sim.reshape(-1),
211
+ size=(_A.shape[0], _A.shape[1]),
212
+ device=_A.device,
213
+ )
214
+ _A = _A.to_dense().to(dtype=subgraph_output.dtype)
215
+ _D = _A.sum(-1)
216
+ _A /= _D[:, None]
194
217
 
195
218
  _V = _A @ subgraph_output
196
219
  if move_output_to_cpu:
@@ -272,7 +295,7 @@ def propagate_eigenvectors(
272
295
  # sample subgraph
273
296
  subgraph_indices = run_subgraph_sampling(
274
297
  features,
275
- num_sample=num_sample,
298
+ num_sample,
276
299
  sample_method=sample_method,
277
300
  )
278
301
 
@@ -34,7 +34,7 @@ def _rgb_with_dimensionality_reduction(
34
34
  ) -> Tuple[torch.Tensor, torch.Tensor]:
35
35
  subgraph_indices = run_subgraph_sampling(
36
36
  features,
37
- num_sample=num_sample,
37
+ num_sample,
38
38
  sample_method="farthest",
39
39
  )
40
40
 
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: nystrom_ncut
3
- Version: 0.0.4
3
+ Version: 0.0.6
4
4
  Summary: Normalized Cut and Nyström Approximation
5
5
  Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
6
6
  Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
@@ -0,0 +1,11 @@
1
+ nystrom_ncut/__init__.py,sha256=Cww-_OsyQHLKpgw_Wh28_tUOvIMMr7Ey8w-tH7v99xQ,452
2
+ nystrom_ncut/common.py,sha256=qdR_JwknT9H1Cv5LopwdwZfORFx-O8MLiRI6ZF1Qohc,558
3
+ nystrom_ncut/ncut_pytorch.py,sha256=wRQXUPBOW2_vutocKf0J19HrFVkBYQePAYUEfotLfx4,11701
4
+ nystrom_ncut/nystrom.py,sha256=HbwON9pLW3gEZvOmbDJwkQNHolOo1EBvwBPeh2p2uJE,8833
5
+ nystrom_ncut/propagation_utils.py,sha256=OCqnv7P9kzDlwqeJzNWpJ3TdTEpk7AD0rJhX8MazZYs,13061
6
+ nystrom_ncut/visualize_utils.py,sha256=QmBatlX7Q-ZWF_iJ1zFDnPHFuofz3tCmtoNeeoMPw3U,18558
7
+ nystrom_ncut-0.0.6.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
8
+ nystrom_ncut-0.0.6.dist-info/METADATA,sha256=FD53Ov3g9u4tbBP_Sxxd2hf1yUdg_Hy3ShWq_xGOZFA,6058
9
+ nystrom_ncut-0.0.6.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
10
+ nystrom_ncut-0.0.6.dist-info/top_level.txt,sha256=j7g_j0S048EvguFFnGgD5Ewd3r2H6klsxd5A4dd-wHw,13
11
+ nystrom_ncut-0.0.6.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.7.0)
2
+ Generator: setuptools (75.8.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,11 +0,0 @@
1
- nystrom_ncut/__init__.py,sha256=Cww-_OsyQHLKpgw_Wh28_tUOvIMMr7Ey8w-tH7v99xQ,452
2
- nystrom_ncut/common.py,sha256=qdR_JwknT9H1Cv5LopwdwZfORFx-O8MLiRI6ZF1Qohc,558
3
- nystrom_ncut/ncut_pytorch.py,sha256=8LfznDwhq-WL_vQxbFBFLSzymg9SEDti_zzf9QQLnrA,11651
4
- nystrom_ncut/nystrom.py,sha256=HbwON9pLW3gEZvOmbDJwkQNHolOo1EBvwBPeh2p2uJE,8833
5
- nystrom_ncut/propagation_utils.py,sha256=pigecB0rAmlbCoMNb8zhCyyNwh3QzkxXEnaBsDRE_ns,12136
6
- nystrom_ncut/visualize_utils.py,sha256=oNaDz_Xn12g3knEZZTb-QWVN-wTrnCNE5gn9cu8Xl_U,18569
7
- nystrom_ncut-0.0.4.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
8
- nystrom_ncut-0.0.4.dist-info/METADATA,sha256=dog8rG5_vF31_SJS90ruUeJwnrs3bM635m7KSPLht78,6058
9
- nystrom_ncut-0.0.4.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
10
- nystrom_ncut-0.0.4.dist-info/top_level.txt,sha256=j7g_j0S048EvguFFnGgD5Ewd3r2H6klsxd5A4dd-wHw,13
11
- nystrom_ncut-0.0.4.dist-info/RECORD,,