nystrom-ncut 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -96,7 +96,6 @@ def distance_from_features(
96
96
  D = D / (2 * features.var(dim=0).sum())
97
97
  else:
98
98
  raise ValueError("distance should be 'cosine' or 'euclidean', 'rbf'")
99
-
100
99
  return D
101
100
 
102
101
 
@@ -184,13 +183,37 @@ def propagate_knn(
184
183
  V_list = []
185
184
  for _v in torch.chunk(inp_features, n_chunks, dim=0):
186
185
  _v = _v.to(device)
187
- _A = affinity_from_features(subgraph_features, _v, affinity_focal_gamma, distance).mT
188
186
 
189
- if knn is not None:
190
- mask = torch.full_like(_A, True, dtype=torch.bool)
191
- mask[torch.arange(len(_v))[:, None], _A.topk(knn, dim=-1, largest=True).indices] = False
192
- _A[mask] = 0.0
193
- _A = F.normalize(_A, p=1, dim=-1)
187
+ # _A = affinity_from_features(subgraph_features, _v, affinity_focal_gamma, distance).mT
188
+ # if knn is not None:
189
+ # mask = torch.full_like(_A, True, dtype=torch.bool)
190
+ # mask[torch.arange(len(_v))[:, None], _A.topk(knn, dim=-1, largest=True).indices] = False
191
+ # _A[mask] = 0.0
192
+ # _A = F.normalize(_A, p=1, dim=-1)
193
+
194
+ if distance == 'cosine':
195
+ _A = _v @ subgraph_features.T
196
+ elif distance == 'euclidean':
197
+ _A = - torch.cdist(_v, subgraph_features, p=2)
198
+ elif distance == 'rbf':
199
+ _A = - torch.cdist(_v, subgraph_features, p=2) ** 2
200
+ else:
201
+ raise ValueError("distance should be 'cosine' or 'euclidean', 'rbf'")
202
+
203
+ # keep topk KNN for each row
204
+ topk_sim, topk_idx = _A.topk(knn, dim=-1, largest=True)
205
+ row_id = torch.arange(topk_idx.shape[0], device=_A.device)[:, None].expand(
206
+ -1, topk_idx.shape[1]
207
+ )
208
+ _A = torch.sparse_coo_tensor(
209
+ torch.stack([row_id, topk_idx], dim=-1).reshape(-1, 2).T,
210
+ topk_sim.reshape(-1),
211
+ size=(_A.shape[0], _A.shape[1]),
212
+ device=_A.device,
213
+ )
214
+ _A = _A.to_dense().to(dtype=subgraph_output.dtype)
215
+ _D = _A.sum(-1)
216
+ _A /= _D[:, None]
194
217
 
195
218
  _V = _A @ subgraph_output
196
219
  if move_output_to_cpu:
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: nystrom_ncut
3
- Version: 0.0.5
3
+ Version: 0.0.6
4
4
  Summary: Normalized Cut and Nyström Approximation
5
5
  Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
6
6
  Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
@@ -2,10 +2,10 @@ nystrom_ncut/__init__.py,sha256=Cww-_OsyQHLKpgw_Wh28_tUOvIMMr7Ey8w-tH7v99xQ,452
2
2
  nystrom_ncut/common.py,sha256=qdR_JwknT9H1Cv5LopwdwZfORFx-O8MLiRI6ZF1Qohc,558
3
3
  nystrom_ncut/ncut_pytorch.py,sha256=wRQXUPBOW2_vutocKf0J19HrFVkBYQePAYUEfotLfx4,11701
4
4
  nystrom_ncut/nystrom.py,sha256=HbwON9pLW3gEZvOmbDJwkQNHolOo1EBvwBPeh2p2uJE,8833
5
- nystrom_ncut/propagation_utils.py,sha256=mD6rZ_mwYjYXs1cp5ZaTK0FrJ4YhyCdoIUrdGRP9k-M,12119
5
+ nystrom_ncut/propagation_utils.py,sha256=OCqnv7P9kzDlwqeJzNWpJ3TdTEpk7AD0rJhX8MazZYs,13061
6
6
  nystrom_ncut/visualize_utils.py,sha256=QmBatlX7Q-ZWF_iJ1zFDnPHFuofz3tCmtoNeeoMPw3U,18558
7
- nystrom_ncut-0.0.5.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
8
- nystrom_ncut-0.0.5.dist-info/METADATA,sha256=n9zlRYBD02k478INScrj9V9rZ1mhXTylcMjkmQDgl1A,6058
9
- nystrom_ncut-0.0.5.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
10
- nystrom_ncut-0.0.5.dist-info/top_level.txt,sha256=j7g_j0S048EvguFFnGgD5Ewd3r2H6klsxd5A4dd-wHw,13
11
- nystrom_ncut-0.0.5.dist-info/RECORD,,
7
+ nystrom_ncut-0.0.6.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
8
+ nystrom_ncut-0.0.6.dist-info/METADATA,sha256=FD53Ov3g9u4tbBP_Sxxd2hf1yUdg_Hy3ShWq_xGOZFA,6058
9
+ nystrom_ncut-0.0.6.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
10
+ nystrom_ncut-0.0.6.dist-info/top_level.txt,sha256=j7g_j0S048EvguFFnGgD5Ewd3r2H6klsxd5A4dd-wHw,13
11
+ nystrom_ncut-0.0.6.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.7.0)
2
+ Generator: setuptools (75.8.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5