nystrom-ncut 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -96,7 +96,6 @@ def distance_from_features(
96
96
  D = D / (2 * features.var(dim=0).sum())
97
97
  else:
98
98
  raise ValueError("distance should be 'cosine' or 'euclidean', 'rbf'")
99
-
100
99
  return D
101
100
 
102
101
 
@@ -184,13 +183,37 @@ def propagate_knn(
184
183
  V_list = []
185
184
  for _v in torch.chunk(inp_features, n_chunks, dim=0):
186
185
  _v = _v.to(device)
187
- _A = affinity_from_features(subgraph_features, _v, affinity_focal_gamma, distance).mT
188
186
 
189
- if knn is not None:
190
- mask = torch.full_like(_A, True, dtype=torch.bool)
191
- mask[torch.arange(len(_v))[:, None], _A.topk(knn, dim=-1, largest=True).indices] = False
192
- _A[mask] = 0.0
193
- _A = F.normalize(_A, p=1, dim=-1)
187
+ # _A = affinity_from_features(subgraph_features, _v, affinity_focal_gamma, distance).mT
188
+ # if knn is not None:
189
+ # mask = torch.full_like(_A, True, dtype=torch.bool)
190
+ # mask[torch.arange(len(_v))[:, None], _A.topk(knn, dim=-1, largest=True).indices] = False
191
+ # _A[mask] = 0.0
192
+ # _A = F.normalize(_A, p=1, dim=-1)
193
+
194
+ if distance == 'cosine':
195
+ _A = _v @ subgraph_features.T
196
+ elif distance == 'euclidean':
197
+ _A = - torch.cdist(_v, subgraph_features, p=2)
198
+ elif distance == 'rbf':
199
+ _A = - torch.cdist(_v, subgraph_features, p=2) ** 2
200
+ else:
201
+ raise ValueError("distance should be 'cosine' or 'euclidean', 'rbf'")
202
+
203
+ # keep topk KNN for each row
204
+ topk_sim, topk_idx = _A.topk(knn, dim=-1, largest=True)
205
+ row_id = torch.arange(topk_idx.shape[0], device=_A.device)[:, None].expand(
206
+ -1, topk_idx.shape[1]
207
+ )
208
+ _A = torch.sparse_coo_tensor(
209
+ torch.stack([row_id, topk_idx], dim=-1).reshape(-1, 2).T,
210
+ topk_sim.reshape(-1),
211
+ size=(_A.shape[0], _A.shape[1]),
212
+ device=_A.device,
213
+ )
214
+ _A = _A.to_dense().to(dtype=subgraph_output.dtype)
215
+ _D = _A.sum(-1)
216
+ _A /= _D[:, None]
194
217
 
195
218
  _V = _A @ subgraph_output
196
219
  if move_output_to_cpu:
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: nystrom_ncut
3
- Version: 0.0.5
3
+ Version: 0.0.6
4
4
  Summary: Normalized Cut and Nyström Approximation
5
5
  Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
6
6
  Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
@@ -2,10 +2,10 @@ nystrom_ncut/__init__.py,sha256=Cww-_OsyQHLKpgw_Wh28_tUOvIMMr7Ey8w-tH7v99xQ,452
2
2
  nystrom_ncut/common.py,sha256=qdR_JwknT9H1Cv5LopwdwZfORFx-O8MLiRI6ZF1Qohc,558
3
3
  nystrom_ncut/ncut_pytorch.py,sha256=wRQXUPBOW2_vutocKf0J19HrFVkBYQePAYUEfotLfx4,11701
4
4
  nystrom_ncut/nystrom.py,sha256=HbwON9pLW3gEZvOmbDJwkQNHolOo1EBvwBPeh2p2uJE,8833
5
- nystrom_ncut/propagation_utils.py,sha256=mD6rZ_mwYjYXs1cp5ZaTK0FrJ4YhyCdoIUrdGRP9k-M,12119
5
+ nystrom_ncut/propagation_utils.py,sha256=OCqnv7P9kzDlwqeJzNWpJ3TdTEpk7AD0rJhX8MazZYs,13061
6
6
  nystrom_ncut/visualize_utils.py,sha256=QmBatlX7Q-ZWF_iJ1zFDnPHFuofz3tCmtoNeeoMPw3U,18558
7
- nystrom_ncut-0.0.5.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
8
- nystrom_ncut-0.0.5.dist-info/METADATA,sha256=n9zlRYBD02k478INScrj9V9rZ1mhXTylcMjkmQDgl1A,6058
9
- nystrom_ncut-0.0.5.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
10
- nystrom_ncut-0.0.5.dist-info/top_level.txt,sha256=j7g_j0S048EvguFFnGgD5Ewd3r2H6klsxd5A4dd-wHw,13
11
- nystrom_ncut-0.0.5.dist-info/RECORD,,
7
+ nystrom_ncut-0.0.6.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
8
+ nystrom_ncut-0.0.6.dist-info/METADATA,sha256=FD53Ov3g9u4tbBP_Sxxd2hf1yUdg_Hy3ShWq_xGOZFA,6058
9
+ nystrom_ncut-0.0.6.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
10
+ nystrom_ncut-0.0.6.dist-info/top_level.txt,sha256=j7g_j0S048EvguFFnGgD5Ewd3r2H6klsxd5A4dd-wHw,13
11
+ nystrom_ncut-0.0.6.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.7.0)
2
+ Generator: setuptools (75.8.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5