nystrom-ncut 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- nystrom_ncut/ncut_pytorch.py +2 -1
- nystrom_ncut/propagation_utils.py +32 -9
- nystrom_ncut/visualize_utils.py +1 -1
- {nystrom_ncut-0.0.4.dist-info → nystrom_ncut-0.0.6.dist-info}/METADATA +2 -2
- nystrom_ncut-0.0.6.dist-info/RECORD +11 -0
- {nystrom_ncut-0.0.4.dist-info → nystrom_ncut-0.0.6.dist-info}/WHEEL +1 -1
- nystrom_ncut-0.0.4.dist-info/RECORD +0 -11
- {nystrom_ncut-0.0.4.dist-info → nystrom_ncut-0.0.6.dist-info}/LICENSE +0 -0
- {nystrom_ncut-0.0.4.dist-info → nystrom_ncut-0.0.6.dist-info}/top_level.txt +0 -0
nystrom_ncut/ncut_pytorch.py
CHANGED
@@ -173,9 +173,10 @@ class NCUT(OnlineNystrom):
|
|
173
173
|
else:
|
174
174
|
sampled_indices = run_subgraph_sampling(
|
175
175
|
features,
|
176
|
-
|
176
|
+
self.num_sample,
|
177
177
|
sample_method=self.sample_method,
|
178
178
|
)
|
179
|
+
sampled_indices = torch.sort(sampled_indices).values
|
179
180
|
sampled_features = features[sampled_indices]
|
180
181
|
OnlineNystrom.fit(self, sampled_features)
|
181
182
|
|
@@ -11,7 +11,7 @@ from .common import ceildiv, lazy_normalize
|
|
11
11
|
@torch.no_grad()
|
12
12
|
def run_subgraph_sampling(
|
13
13
|
features: torch.Tensor,
|
14
|
-
num_sample: int
|
14
|
+
num_sample: int,
|
15
15
|
max_draw: int = 1000000,
|
16
16
|
sample_method: Literal["farthest", "random"] = "farthest",
|
17
17
|
):
|
@@ -96,7 +96,6 @@ def distance_from_features(
|
|
96
96
|
D = D / (2 * features.var(dim=0).sum())
|
97
97
|
else:
|
98
98
|
raise ValueError("distance should be 'cosine' or 'euclidean', 'rbf'")
|
99
|
-
|
100
99
|
return D
|
101
100
|
|
102
101
|
|
@@ -184,13 +183,37 @@ def propagate_knn(
|
|
184
183
|
V_list = []
|
185
184
|
for _v in torch.chunk(inp_features, n_chunks, dim=0):
|
186
185
|
_v = _v.to(device)
|
187
|
-
_A = affinity_from_features(subgraph_features, _v, affinity_focal_gamma, distance).mT
|
188
186
|
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
_A =
|
187
|
+
# _A = affinity_from_features(subgraph_features, _v, affinity_focal_gamma, distance).mT
|
188
|
+
# if knn is not None:
|
189
|
+
# mask = torch.full_like(_A, True, dtype=torch.bool)
|
190
|
+
# mask[torch.arange(len(_v))[:, None], _A.topk(knn, dim=-1, largest=True).indices] = False
|
191
|
+
# _A[mask] = 0.0
|
192
|
+
# _A = F.normalize(_A, p=1, dim=-1)
|
193
|
+
|
194
|
+
if distance == 'cosine':
|
195
|
+
_A = _v @ subgraph_features.T
|
196
|
+
elif distance == 'euclidean':
|
197
|
+
_A = - torch.cdist(_v, subgraph_features, p=2)
|
198
|
+
elif distance == 'rbf':
|
199
|
+
_A = - torch.cdist(_v, subgraph_features, p=2) ** 2
|
200
|
+
else:
|
201
|
+
raise ValueError("distance should be 'cosine' or 'euclidean', 'rbf'")
|
202
|
+
|
203
|
+
# keep topk KNN for each row
|
204
|
+
topk_sim, topk_idx = _A.topk(knn, dim=-1, largest=True)
|
205
|
+
row_id = torch.arange(topk_idx.shape[0], device=_A.device)[:, None].expand(
|
206
|
+
-1, topk_idx.shape[1]
|
207
|
+
)
|
208
|
+
_A = torch.sparse_coo_tensor(
|
209
|
+
torch.stack([row_id, topk_idx], dim=-1).reshape(-1, 2).T,
|
210
|
+
topk_sim.reshape(-1),
|
211
|
+
size=(_A.shape[0], _A.shape[1]),
|
212
|
+
device=_A.device,
|
213
|
+
)
|
214
|
+
_A = _A.to_dense().to(dtype=subgraph_output.dtype)
|
215
|
+
_D = _A.sum(-1)
|
216
|
+
_A /= _D[:, None]
|
194
217
|
|
195
218
|
_V = _A @ subgraph_output
|
196
219
|
if move_output_to_cpu:
|
@@ -272,7 +295,7 @@ def propagate_eigenvectors(
|
|
272
295
|
# sample subgraph
|
273
296
|
subgraph_indices = run_subgraph_sampling(
|
274
297
|
features,
|
275
|
-
num_sample
|
298
|
+
num_sample,
|
276
299
|
sample_method=sample_method,
|
277
300
|
)
|
278
301
|
|
nystrom_ncut/visualize_utils.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.2
|
2
2
|
Name: nystrom_ncut
|
3
|
-
Version: 0.0.
|
3
|
+
Version: 0.0.6
|
4
4
|
Summary: Normalized Cut and Nyström Approximation
|
5
5
|
Author-email: Huzheng Yang <huze.yann@gmail.com>, Wentinn Liao <wentinn.liao@gmail.com>
|
6
6
|
Project-URL: Documentation, https://github.com/JophiArcana/Nystrom-NCUT/
|
@@ -0,0 +1,11 @@
|
|
1
|
+
nystrom_ncut/__init__.py,sha256=Cww-_OsyQHLKpgw_Wh28_tUOvIMMr7Ey8w-tH7v99xQ,452
|
2
|
+
nystrom_ncut/common.py,sha256=qdR_JwknT9H1Cv5LopwdwZfORFx-O8MLiRI6ZF1Qohc,558
|
3
|
+
nystrom_ncut/ncut_pytorch.py,sha256=wRQXUPBOW2_vutocKf0J19HrFVkBYQePAYUEfotLfx4,11701
|
4
|
+
nystrom_ncut/nystrom.py,sha256=HbwON9pLW3gEZvOmbDJwkQNHolOo1EBvwBPeh2p2uJE,8833
|
5
|
+
nystrom_ncut/propagation_utils.py,sha256=OCqnv7P9kzDlwqeJzNWpJ3TdTEpk7AD0rJhX8MazZYs,13061
|
6
|
+
nystrom_ncut/visualize_utils.py,sha256=QmBatlX7Q-ZWF_iJ1zFDnPHFuofz3tCmtoNeeoMPw3U,18558
|
7
|
+
nystrom_ncut-0.0.6.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
|
8
|
+
nystrom_ncut-0.0.6.dist-info/METADATA,sha256=FD53Ov3g9u4tbBP_Sxxd2hf1yUdg_Hy3ShWq_xGOZFA,6058
|
9
|
+
nystrom_ncut-0.0.6.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
10
|
+
nystrom_ncut-0.0.6.dist-info/top_level.txt,sha256=j7g_j0S048EvguFFnGgD5Ewd3r2H6klsxd5A4dd-wHw,13
|
11
|
+
nystrom_ncut-0.0.6.dist-info/RECORD,,
|
@@ -1,11 +0,0 @@
|
|
1
|
-
nystrom_ncut/__init__.py,sha256=Cww-_OsyQHLKpgw_Wh28_tUOvIMMr7Ey8w-tH7v99xQ,452
|
2
|
-
nystrom_ncut/common.py,sha256=qdR_JwknT9H1Cv5LopwdwZfORFx-O8MLiRI6ZF1Qohc,558
|
3
|
-
nystrom_ncut/ncut_pytorch.py,sha256=8LfznDwhq-WL_vQxbFBFLSzymg9SEDti_zzf9QQLnrA,11651
|
4
|
-
nystrom_ncut/nystrom.py,sha256=HbwON9pLW3gEZvOmbDJwkQNHolOo1EBvwBPeh2p2uJE,8833
|
5
|
-
nystrom_ncut/propagation_utils.py,sha256=pigecB0rAmlbCoMNb8zhCyyNwh3QzkxXEnaBsDRE_ns,12136
|
6
|
-
nystrom_ncut/visualize_utils.py,sha256=oNaDz_Xn12g3knEZZTb-QWVN-wTrnCNE5gn9cu8Xl_U,18569
|
7
|
-
nystrom_ncut-0.0.4.dist-info/LICENSE,sha256=2bm9uFabQZ3Ykb_SaSU_uUbAj2-htc6WJQmS_65qD00,1073
|
8
|
-
nystrom_ncut-0.0.4.dist-info/METADATA,sha256=dog8rG5_vF31_SJS90ruUeJwnrs3bM635m7KSPLht78,6058
|
9
|
-
nystrom_ncut-0.0.4.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
|
10
|
-
nystrom_ncut-0.0.4.dist-info/top_level.txt,sha256=j7g_j0S048EvguFFnGgD5Ewd3r2H6klsxd5A4dd-wHw,13
|
11
|
-
nystrom_ncut-0.0.4.dist-info/RECORD,,
|
File without changes
|
File without changes
|