ncut-pytorch 3.0.0.dev6__tar.gz → 3.0.0.dev8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/PKG-INFO +1 -1
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/color/coloring.py +8 -8
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/color/mspace.py +132 -267
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/ncuts/ncut_nystrom.py +1 -1
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/predictor.py +3 -3
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/vision_predictor.py +1 -1
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/utils/grad.py +9 -7
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch.egg-info/PKG-INFO +1 -1
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/pyproject.toml +1 -1
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/LICENSE +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/README.md +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/__init__.py +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/color/__init__.py +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/color/mspace_nopl.py +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/ncut.py +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/ncuts/__init__.py +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/ncuts/ncut_click.py +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/ncuts/ncut_kway.py +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/__init__.py +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/dino/__init__.py +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/dino/api.py +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/dino/dinov3.py +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/dino/hires_dino.py +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/dino/lowres_dino.py +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/dino/patch.py +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/dino/transform.py +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/dino_predictor.py +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/jafar_predictor.py +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/utils/__init__.py +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/utils/device.py +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/utils/math.py +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/utils/sample.py +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/utils/sigma.py +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/utils/torch_mod.py +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch.egg-info/SOURCES.txt +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch.egg-info/dependency_links.txt +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch.egg-info/requires.txt +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch.egg-info/top_level.txt +0 -0
- {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/setup.cfg +0 -0
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
__all__ = ["mspace_color", "tsne_color", "umap_color", "umap_sphere_color", "rotate_rgb_cube", "convert_to_lab_color"]
|
|
2
2
|
|
|
3
3
|
import warnings
|
|
4
|
-
from typing import Any, Callable, Dict, Literal, Tuple, Optional
|
|
4
|
+
from typing import Any, Callable, Dict, Literal, Tuple, Optional, List
|
|
5
5
|
|
|
6
6
|
from numba.core.types import none
|
|
7
7
|
import numpy as np
|
|
@@ -19,7 +19,7 @@ def _identity(X: torch.Tensor) -> torch.Tensor:
|
|
|
19
19
|
def mspace_color(
|
|
20
20
|
X: torch.Tensor,
|
|
21
21
|
q: float = 0.95,
|
|
22
|
-
|
|
22
|
+
n_eig_list: Optional[List[int]] = [4, 16, 64],
|
|
23
23
|
n_dim: int = 3,
|
|
24
24
|
training_steps: int = 1000,
|
|
25
25
|
progress_bar: bool = False,
|
|
@@ -34,15 +34,15 @@ def mspace_color(
|
|
|
34
34
|
|
|
35
35
|
low_dim_embedding = mspace_viz_transform(
|
|
36
36
|
X=X,
|
|
37
|
-
|
|
38
|
-
|
|
37
|
+
n_eig_list=n_eig_list,
|
|
38
|
+
z_dim=n_dim,
|
|
39
39
|
training_steps=training_steps,
|
|
40
|
-
decoder_training_steps=0,
|
|
41
40
|
progress_bar=progress_bar,
|
|
42
|
-
|
|
43
|
-
|
|
41
|
+
flag_loss_mode='z',
|
|
42
|
+
flag_loss=1.0,
|
|
43
|
+
recon_loss=.001,
|
|
44
44
|
zero_center_loss=0.001,
|
|
45
|
-
repulsion_loss=0.
|
|
45
|
+
repulsion_loss=0.001,
|
|
46
46
|
**kwargs)
|
|
47
47
|
|
|
48
48
|
rgb = rgb_from_nd_colormap(low_dim_embedding, q=q)
|
|
@@ -2,9 +2,10 @@ __all__ = ["train_mspace_model", "mspace_viz_transform"]
|
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
4
|
from collections import defaultdict
|
|
5
|
-
from functools import partial
|
|
5
|
+
from functools import partial, wraps
|
|
6
6
|
import warnings
|
|
7
7
|
import pytorch_lightning as pl
|
|
8
|
+
from typing import List
|
|
8
9
|
|
|
9
10
|
import numpy as np
|
|
10
11
|
import torch
|
|
@@ -15,58 +16,21 @@ from tqdm import tqdm
|
|
|
15
16
|
|
|
16
17
|
from ncut_pytorch.utils.device import auto_device
|
|
17
18
|
from ncut_pytorch import ncut_fn
|
|
18
|
-
from ncut_pytorch.utils.math import
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
right = _eigvec_hat @ _eigvec_hat.T
|
|
26
|
-
left = left * weight[:, None] * weight[None, :]
|
|
27
|
-
right = right * weight[:, None] * weight[None, :]
|
|
28
|
-
loss = F.l1_loss(left, right)
|
|
29
|
-
return loss
|
|
30
|
-
|
|
31
|
-
def flag_space_loss(eigvec_gt, eigvec_hat, n_eig, start=2, step_mult=2, weight=None):
|
|
32
|
-
if torch.all(eigvec_gt == 0) or torch.all(eigvec_hat == 0):
|
|
33
|
-
return torch.tensor(0, device=eigvec_gt.device)
|
|
34
|
-
|
|
35
|
-
if weight is None:
|
|
36
|
-
weight = torch.ones(eigvec_gt.shape[0], device=eigvec_gt.device, dtype=eigvec_gt.dtype)
|
|
37
|
-
|
|
38
|
-
loss = 0
|
|
39
|
-
n_eig = start // step_mult
|
|
40
|
-
while True:
|
|
41
|
-
n_eig *= step_mult
|
|
42
|
-
loss += _flag_ncut_loss(eigvec_gt, eigvec_hat, n_eig, weight)
|
|
43
|
-
if n_eig > eigvec_gt.shape[1] or n_eig > eigvec_hat.shape[1]:
|
|
44
|
-
break
|
|
45
|
-
return loss
|
|
46
|
-
|
|
47
|
-
def filter_closeby_eigval(eigvec, eigval, threshold=1e-3):
|
|
48
|
-
# filter out eigvals that are too close to each other
|
|
49
|
-
# so the gradient is more stable
|
|
50
|
-
eigval_diff = torch.diff(eigval).abs()
|
|
51
|
-
keep_idx = torch.where(eigval_diff > threshold)[0]
|
|
52
|
-
return eigvec[:, keep_idx], eigval[keep_idx]
|
|
53
|
-
|
|
54
|
-
def ncut_wrapper(features, n_eig, sigma=None):
|
|
55
|
-
# TODO: gradient is not stable.
|
|
56
|
-
|
|
57
|
-
# features.requires_grad_(True)
|
|
58
|
-
sigma = sigma or features.std(0).sum().item()
|
|
59
|
-
eigvec, eigval = ncut_fn(features, n_eig, sigma=sigma)
|
|
19
|
+
from ncut_pytorch.utils.math import normalize_affinity, rbf_affinity, cosine_affinity
|
|
20
|
+
from ncut_pytorch.utils.grad import eigvec_outer_product
|
|
21
|
+
from ncut_pytorch.utils.sigma import find_sigma_by_degree
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def get_eigvec_outer_product(features: torch.Tensor, n_eig_list: List[int]) -> torch.Tensor:
|
|
25
|
+
sigma = find_sigma_by_degree(features)
|
|
60
26
|
W = rbf_affinity(features, sigma=sigma)
|
|
61
|
-
# W = cosine_affinity(features, sigma=1.0)
|
|
62
27
|
A = normalize_affinity(W)
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
return eigvec, eigval
|
|
28
|
+
eigval_masks = torch.zeros(len(n_eig_list), A.shape[0], dtype=torch.bool)
|
|
29
|
+
for i, n_eig in enumerate(n_eig_list):
|
|
30
|
+
eigval_masks[i, :n_eig] = True
|
|
31
|
+
P = eigvec_outer_product(A, eigval_masks)
|
|
32
|
+
return P
|
|
33
|
+
|
|
70
34
|
|
|
71
35
|
|
|
72
36
|
class MovingMinMax(nn.Module):
|
|
@@ -112,12 +76,12 @@ class MLP(nn.Module):
|
|
|
112
76
|
|
|
113
77
|
|
|
114
78
|
class MspaceAutoEncoder(nn.Module):
|
|
115
|
-
def __init__(self, in_dim, out_dim,
|
|
79
|
+
def __init__(self, in_dim, out_dim, z_dim, n_layer=4, latent_dim=256, encoder_activation='gelu', decoder_activation='gelu', final_activation='identity'):
|
|
116
80
|
super().__init__()
|
|
117
|
-
self.encoder = MLP(in_dim,
|
|
81
|
+
self.encoder = MLP(in_dim, z_dim, n_layer, latent_dim, encoder_activation, final_activation)
|
|
118
82
|
self.decoder = nn.Sequential(
|
|
119
|
-
MovingMinMax(
|
|
120
|
-
MLP(
|
|
83
|
+
MovingMinMax(z_dim),
|
|
84
|
+
MLP(z_dim, out_dim, n_layer, latent_dim, decoder_activation)
|
|
121
85
|
)
|
|
122
86
|
|
|
123
87
|
|
|
@@ -260,34 +224,29 @@ class TrainDecoder(pl.LightningModule):
|
|
|
260
224
|
class TrainEncoder(pl.LightningModule):
|
|
261
225
|
N_ITER_PER_STEP = 10
|
|
262
226
|
|
|
263
|
-
def __init__(self, in_dim, out_dim,
|
|
227
|
+
def __init__(self, in_dim, out_dim, z_dim=2, n_eig_list=[4, 16, 64],
|
|
264
228
|
n_layer=4, latent_dim=256,
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
repulsion_loss=
|
|
268
|
-
|
|
269
|
-
lr=0.001, progress_bar=True, training_steps=500,
|
|
229
|
+
flag_loss_mode='z_eigvec', # 'z_eigvec' or 'z'
|
|
230
|
+
flag_loss=1.0, recon_loss=1e-3,
|
|
231
|
+
repulsion_loss=1e-3, zero_center_loss=1e-3,
|
|
232
|
+
lr=1e-3, progress_bar=True, training_steps=1000,
|
|
270
233
|
encoder_activation='gelu', decoder_activation='gelu',
|
|
271
234
|
final_activation='identity',
|
|
272
|
-
log_grad_norm=
|
|
235
|
+
log_grad_norm=True,
|
|
273
236
|
**kwargs):
|
|
274
237
|
super().__init__()
|
|
275
238
|
|
|
276
|
-
self.mspace_ae = MspaceAutoEncoder(in_dim, out_dim,
|
|
239
|
+
self.mspace_ae = MspaceAutoEncoder(in_dim, out_dim, z_dim, n_layer, latent_dim, encoder_activation, decoder_activation, final_activation)
|
|
277
240
|
|
|
278
241
|
self.loss_history = defaultdict(list)
|
|
279
242
|
|
|
280
|
-
self.
|
|
243
|
+
self.flag_loss_mode = flag_loss_mode
|
|
244
|
+
self.flag_loss = flag_loss
|
|
281
245
|
self.recon_loss = recon_loss
|
|
282
|
-
self.riemann_curvature_loss = riemann_curvature_loss
|
|
283
|
-
self.axis_align_loss = axis_align_loss
|
|
284
246
|
self.repulsion_loss = repulsion_loss
|
|
285
|
-
self.attraction_loss = attraction_loss
|
|
286
|
-
self.boundary_loss = boundary_loss
|
|
287
247
|
self.zero_center_loss = zero_center_loss
|
|
288
248
|
self.lr = lr
|
|
289
|
-
self.
|
|
290
|
-
self.n_elbow = n_elbow
|
|
249
|
+
self.n_eig_list = n_eig_list
|
|
291
250
|
|
|
292
251
|
self.progress_bar = progress_bar
|
|
293
252
|
self.training_steps = training_steps
|
|
@@ -303,11 +262,11 @@ class TrainEncoder(pl.LightningModule):
|
|
|
303
262
|
return self.mspace_ae.encoder(x)
|
|
304
263
|
|
|
305
264
|
|
|
306
|
-
def _log_loss(self, loss, name, log_grad_norm=False):
|
|
307
|
-
if self.logger is None:
|
|
308
|
-
return
|
|
309
|
-
self.logger.log_metrics({f'loss/{name}': loss.item()}, step=self.global_step)
|
|
310
|
-
self.loss_history[f"loss/{name}"].append(loss.item())
|
|
265
|
+
def _log_loss(self, loss, name, log_grad_norm=False, iteration=0):
|
|
266
|
+
# if self.logger is None:
|
|
267
|
+
# return
|
|
268
|
+
# self.logger.log_metrics({f'loss/{name}': loss.item()}, step=self.global_step)
|
|
269
|
+
# self.loss_history[f"loss/{name}"].append(loss.item())
|
|
311
270
|
if log_grad_norm:
|
|
312
271
|
grad_norm = 0
|
|
313
272
|
self.manual_backward(loss, retain_graph=True)
|
|
@@ -315,8 +274,10 @@ class TrainEncoder(pl.LightningModule):
|
|
|
315
274
|
if param.grad is not None:
|
|
316
275
|
grad_norm += param.grad.norm().item() ** 2
|
|
317
276
|
grad_norm = grad_norm ** 0.5
|
|
318
|
-
self.logger.log_metrics({f'grad/{name}': grad_norm}, step=self.global_step)
|
|
319
|
-
self.loss_history[f"grad/{name}"].append(grad_norm)
|
|
277
|
+
# self.logger.log_metrics({f'grad/{name}': grad_norm}, step=self.global_step)
|
|
278
|
+
# self.loss_history[f"grad/{name}"].append(grad_norm)
|
|
279
|
+
if iteration % 10 == 0:
|
|
280
|
+
print(f"loss/{name} grad_norm: {grad_norm:.2e}, iteration: {iteration}")
|
|
320
281
|
self.optimizers().zero_grad()
|
|
321
282
|
|
|
322
283
|
def training_step(self, batch):
|
|
@@ -328,12 +289,10 @@ class TrainEncoder(pl.LightningModule):
|
|
|
328
289
|
|
|
329
290
|
input_feats, output_feats = batch
|
|
330
291
|
|
|
331
|
-
|
|
332
|
-
# with torch.no_grad():
|
|
292
|
+
P_gt = {}
|
|
333
293
|
# Compute eigvec_gt only once for each set of iterations
|
|
334
|
-
if self.
|
|
335
|
-
|
|
336
|
-
stored_eigvec_gt = eigvec_gt
|
|
294
|
+
if self.flag_loss != 0:
|
|
295
|
+
P_gt = get_eigvec_outer_product(input_feats, self.n_eig_list)
|
|
337
296
|
|
|
338
297
|
|
|
339
298
|
# Run the same batch 10 times, updating parameters after each iteration
|
|
@@ -344,28 +303,36 @@ class TrainEncoder(pl.LightningModule):
|
|
|
344
303
|
log_grad_norm = iteration == 0 and self.log_grad_norm
|
|
345
304
|
|
|
346
305
|
total_loss = 0
|
|
347
|
-
if self.
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
self.
|
|
306
|
+
if self.flag_loss != 0:
|
|
307
|
+
if self.flag_loss_mode == 'z_eigvec':
|
|
308
|
+
P_hat = get_eigvec_outer_product(feats_compressed, self.n_eig_list)
|
|
309
|
+
flag_loss = F.mse_loss(P_gt, P_hat)
|
|
310
|
+
elif self.flag_loss_mode == 'z':
|
|
311
|
+
P_hat = feats_compressed @ feats_compressed.T
|
|
312
|
+
flag_loss = F.mse_loss(P_gt.mean(0), P_hat)
|
|
313
|
+
flag_loss = flag_loss * self.flag_loss
|
|
314
|
+
total_loss += flag_loss
|
|
315
|
+
self._log_loss(flag_loss, "flag_loss", log_grad_norm=log_grad_norm, iteration=self.trainer.global_step)
|
|
355
316
|
|
|
356
317
|
if self.recon_loss > 0:
|
|
357
318
|
feats_uncompressed = self.mspace_ae.decoder(feats_compressed)
|
|
358
|
-
recon_loss = F.
|
|
319
|
+
recon_loss = F.mse_loss(output_feats, feats_uncompressed)
|
|
359
320
|
recon_loss = recon_loss * self.recon_loss
|
|
360
321
|
total_loss += recon_loss
|
|
361
|
-
self._log_loss(recon_loss, "
|
|
322
|
+
self._log_loss(recon_loss, "recon_loss", log_grad_norm=log_grad_norm, iteration=self.trainer.global_step)
|
|
362
323
|
|
|
363
324
|
if self.zero_center_loss > 0:
|
|
364
325
|
# zero_center_loss = feats_compressed.abs().mean()
|
|
365
326
|
zero_center_loss = (feats_compressed ** 2).mean()
|
|
366
327
|
zero_center_loss = zero_center_loss * self.zero_center_loss
|
|
367
328
|
total_loss += zero_center_loss
|
|
368
|
-
self._log_loss(zero_center_loss, "
|
|
329
|
+
self._log_loss(zero_center_loss, "zero_center_loss", log_grad_norm=log_grad_norm, iteration=self.trainer.global_step)
|
|
330
|
+
|
|
331
|
+
if self.repulsion_loss > 0:
|
|
332
|
+
repulsion_loss = compute_repulsion_loss(feats_compressed)
|
|
333
|
+
repulsion_loss = repulsion_loss * self.repulsion_loss
|
|
334
|
+
total_loss += repulsion_loss
|
|
335
|
+
self._log_loss(repulsion_loss, "repulsion_loss", log_grad_norm=log_grad_norm, iteration=self.trainer.global_step)
|
|
369
336
|
|
|
370
337
|
# Log the loss for this iteration
|
|
371
338
|
if self.logger is not None:
|
|
@@ -387,86 +354,6 @@ class TrainEncoder(pl.LightningModule):
|
|
|
387
354
|
return optimizer
|
|
388
355
|
|
|
389
356
|
|
|
390
|
-
# Moving Average Callback
|
|
391
|
-
class BestModelsAvgCallback(pl.Callback):
|
|
392
|
-
"""
|
|
393
|
-
Callback to store the top models with the lowest total loss and average them at the end.
|
|
394
|
-
"""
|
|
395
|
-
def __init__(self, top_k=10):
|
|
396
|
-
super().__init__()
|
|
397
|
-
self.top_k = top_k
|
|
398
|
-
self.best_models = [] # List of tuples (loss, params)
|
|
399
|
-
self.best_loss = float('inf')
|
|
400
|
-
|
|
401
|
-
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
|
402
|
-
"""Store the model if it has a lower loss than the current best models."""
|
|
403
|
-
|
|
404
|
-
# skip if the global step is on the frist 80% of the training steps
|
|
405
|
-
if trainer.global_step < trainer.max_steps * 0.8:
|
|
406
|
-
return
|
|
407
|
-
|
|
408
|
-
# Get current model parameters
|
|
409
|
-
current_params = {}
|
|
410
|
-
for name, param in pl_module.named_parameters():
|
|
411
|
-
current_params[name] = param.data.clone()
|
|
412
|
-
|
|
413
|
-
# Get the current total loss from the outputs
|
|
414
|
-
# The outputs from training_step is the total loss
|
|
415
|
-
current_loss = outputs['loss'].item()
|
|
416
|
-
|
|
417
|
-
# If we have fewer than top_k models or this is a better model
|
|
418
|
-
if len(self.best_models) < self.top_k or current_loss < self.best_loss:
|
|
419
|
-
# Add to best models
|
|
420
|
-
self.best_models.append((current_loss, current_params))
|
|
421
|
-
|
|
422
|
-
# Sort by loss (ascending)
|
|
423
|
-
self.best_models.sort(key=lambda x: x[0])
|
|
424
|
-
|
|
425
|
-
# Keep only top_k models
|
|
426
|
-
if len(self.best_models) > self.top_k:
|
|
427
|
-
self.best_models = self.best_models[:self.top_k]
|
|
428
|
-
|
|
429
|
-
# Update best loss
|
|
430
|
-
self.best_loss = self.best_models[0][0]
|
|
431
|
-
|
|
432
|
-
def on_train_end(self, trainer, pl_module):
|
|
433
|
-
"""Average the best model parameters and apply them to the model at the end of training."""
|
|
434
|
-
if self.best_models:
|
|
435
|
-
# Initialize averaged parameters
|
|
436
|
-
avg_params = {}
|
|
437
|
-
|
|
438
|
-
# Get parameter names from the first model
|
|
439
|
-
param_names = self.best_models[0][1].keys()
|
|
440
|
-
|
|
441
|
-
# Initialize with zeros of the same shape
|
|
442
|
-
for name in param_names:
|
|
443
|
-
avg_params[name] = torch.zeros_like(self.best_models[0][1][name])
|
|
444
|
-
|
|
445
|
-
# Sum all parameters from best models
|
|
446
|
-
for _, params in self.best_models:
|
|
447
|
-
for name in param_names:
|
|
448
|
-
avg_params[name] += params[name]
|
|
449
|
-
|
|
450
|
-
# Divide by the number of models to get the average
|
|
451
|
-
for name in param_names:
|
|
452
|
-
avg_params[name] /= len(self.best_models)
|
|
453
|
-
|
|
454
|
-
# Apply averaged parameters to the model
|
|
455
|
-
for name, param in pl_module.named_parameters():
|
|
456
|
-
if name in avg_params:
|
|
457
|
-
param.data.copy_(avg_params[name])
|
|
458
|
-
|
|
459
|
-
# Log the best loss and average loss
|
|
460
|
-
if pl_module.logger is not None:
|
|
461
|
-
pl_module.logger.log_metrics({
|
|
462
|
-
'best_loss': self.best_loss,
|
|
463
|
-
'avg_loss': sum(loss for loss, _ in self.best_models) / len(self.best_models)
|
|
464
|
-
}, step=trainer.global_step)
|
|
465
|
-
|
|
466
|
-
self.best_models = []
|
|
467
|
-
self.best_loss = float('inf')
|
|
468
|
-
|
|
469
|
-
|
|
470
357
|
def compute_repulsion_loss(
|
|
471
358
|
points: torch.Tensor, # [N, D]
|
|
472
359
|
) -> torch.Tensor:
|
|
@@ -475,115 +362,92 @@ def compute_repulsion_loss(
|
|
|
475
362
|
mask = torch.eye(points.shape[0], device=points.device).bool()
|
|
476
363
|
dist_matrix = dist_matrix + mask * 1e10
|
|
477
364
|
nearest_dists, _ = torch.min(dist_matrix, dim=1)
|
|
478
|
-
repulsion = 1.0 / (nearest_dists +
|
|
365
|
+
repulsion = 1.0 / (nearest_dists + 1)
|
|
479
366
|
return torch.mean(repulsion)
|
|
480
367
|
|
|
481
368
|
|
|
482
|
-
def
|
|
483
|
-
|
|
484
|
-
logger=False, use_wandb=False, model_avg_window=3, **model_kwargs):
|
|
485
|
-
# Disable lightning logs only for this function
|
|
486
|
-
original_lightning_level = logging.getLogger('lightning').level
|
|
487
|
-
original_pytorch_lightning_level = logging.getLogger("pytorch_lightning").level
|
|
488
|
-
|
|
489
|
-
logging.getLogger('lightning').setLevel(0)
|
|
490
|
-
logging.getLogger("pytorch_lightning").setLevel(0)
|
|
491
|
-
|
|
369
|
+
def suppress_lightning_logs(func):
|
|
370
|
+
"""Temporarily suppresses noisy PyTorch Lightning startup logs."""
|
|
492
371
|
class IgnorePLFilter(logging.Filter):
|
|
493
372
|
def filter(self, record):
|
|
494
373
|
keywords = ['available:', 'CUDA', 'LOCAL_RANK:']
|
|
495
374
|
return not any(keyword in record.getMessage() for keyword in keywords)
|
|
496
|
-
|
|
497
|
-
pl_filter = IgnorePLFilter()
|
|
498
|
-
logger_rank_zero = logging.getLogger('pytorch_lightning.utilities.rank_zero')
|
|
499
|
-
logger_cuda = logging.getLogger('pytorch_lightning.accelerators.cuda')
|
|
500
|
-
logger_rank_zero.addFilter(pl_filter)
|
|
501
|
-
logger_cuda.addFilter(pl_filter)
|
|
502
|
-
|
|
503
|
-
try:
|
|
504
|
-
compress_feats = compress_feats.float().cpu()
|
|
505
|
-
uncompress_feats = uncompress_feats.float().cpu()
|
|
506
|
-
l, c_in = compress_feats.shape
|
|
507
|
-
c_out = uncompress_feats.shape[1]
|
|
508
375
|
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
logger_rank_zero.removeFilter(pl_filter)
|
|
548
|
-
logger_cuda.removeFilter(pl_filter)
|
|
376
|
+
@wraps(func)
|
|
377
|
+
def wrapper(*args, **kwargs):
|
|
378
|
+
original_lightning_level = logging.getLogger('lightning').level
|
|
379
|
+
original_pytorch_lightning_level = logging.getLogger("pytorch_lightning").level
|
|
380
|
+
logging.getLogger('lightning').setLevel(0)
|
|
381
|
+
logging.getLogger("pytorch_lightning").setLevel(0)
|
|
382
|
+
pl_filter = IgnorePLFilter()
|
|
383
|
+
logger_rank_zero = logging.getLogger('pytorch_lightning.utilities.rank_zero')
|
|
384
|
+
logger_cuda = logging.getLogger('pytorch_lightning.accelerators.cuda')
|
|
385
|
+
logger_rank_zero.addFilter(pl_filter)
|
|
386
|
+
logger_cuda.addFilter(pl_filter)
|
|
387
|
+
try:
|
|
388
|
+
return func(*args, **kwargs)
|
|
389
|
+
finally:
|
|
390
|
+
logging.getLogger('lightning').setLevel(original_lightning_level)
|
|
391
|
+
logging.getLogger("pytorch_lightning").setLevel(original_pytorch_lightning_level)
|
|
392
|
+
logger_rank_zero.removeFilter(pl_filter)
|
|
393
|
+
logger_cuda.removeFilter(pl_filter)
|
|
394
|
+
|
|
395
|
+
return wrapper
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
@suppress_lightning_logs
|
|
399
|
+
def train_mspace_model(compress_feats, uncompress_feats, training_steps=1000, decoder_training_steps=1000,
|
|
400
|
+
batch_size=1000, return_trainer=False, progress_bar=False,
|
|
401
|
+
logger=False, use_wandb=False, **model_kwargs):
|
|
402
|
+
# check args
|
|
403
|
+
valid_keys = ['z_dim', 'flag_loss', 'flag_loss_mode', 'recon_loss', 'repulsion_loss', 'zero_center_loss', 'n_eig_list', 'n_layer', 'latent_dim', 'lr']
|
|
404
|
+
for key in model_kwargs.keys():
|
|
405
|
+
if key not in valid_keys:
|
|
406
|
+
raise ValueError(f"Invalid argument key: {key}. Valid keys: {valid_keys}")
|
|
407
|
+
|
|
408
|
+
compress_feats = compress_feats.float().cpu()
|
|
409
|
+
uncompress_feats = uncompress_feats.float().cpu()
|
|
410
|
+
l, c_in = compress_feats.shape
|
|
411
|
+
c_out = uncompress_feats.shape[1]
|
|
412
|
+
|
|
413
|
+
model = TrainEncoder(c_in, c_out, training_steps=training_steps, progress_bar=progress_bar, **model_kwargs)
|
|
549
414
|
|
|
550
|
-
|
|
415
|
+
dataset = TensorDataset(compress_feats, uncompress_feats)
|
|
416
|
+
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
|
|
417
|
+
|
|
418
|
+
trainer_args = {
|
|
419
|
+
'accelerator': auto_device(),
|
|
420
|
+
'devices': 1,
|
|
421
|
+
'enable_checkpointing': False,
|
|
422
|
+
'enable_progress_bar': False,
|
|
423
|
+
'enable_model_summary': False,
|
|
424
|
+
}
|
|
551
425
|
|
|
426
|
+
if use_wandb and not logger:
|
|
427
|
+
logger = pl.loggers.WandbLogger(project='mspace', name='mspace')
|
|
552
428
|
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
retry_count = 0
|
|
557
|
-
original_n_eig = kwargs.get('n_eig', 8)
|
|
429
|
+
# train the autoencoder jointly
|
|
430
|
+
trainer = pl.Trainer(max_steps=training_steps, logger=logger, **trainer_args)
|
|
431
|
+
trainer.fit(model, dataloader)
|
|
558
432
|
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
else:
|
|
575
|
-
kwargs['n_eig'] = n_eig
|
|
576
|
-
warnings.warn(f"Error in training mspace model: {e}. Trying with n_eig={n_eig}... Retrying ({retry_count}/{max_retries})...")
|
|
577
|
-
|
|
578
|
-
# If we've exhausted all retries, raise the exception
|
|
579
|
-
if retry_count >= max_retries:
|
|
580
|
-
raise Exception(f"Failed to train mspace model after {max_retries} retries. Last error: {e}")
|
|
581
|
-
|
|
582
|
-
torch.cuda.empty_cache()
|
|
433
|
+
mspace_ae = model.mspace_ae
|
|
434
|
+
|
|
435
|
+
if decoder_training_steps > 0:
|
|
436
|
+
# train the decoder only
|
|
437
|
+
trainer = pl.Trainer(max_steps=decoder_training_steps, logger=logger, **trainer_args)
|
|
438
|
+
model2 = TrainDecoder(mspace_ae, progress_bar=progress_bar, training_steps=decoder_training_steps, **model_kwargs)
|
|
439
|
+
trainer.fit(model2, dataloader)
|
|
440
|
+
model.mspace_ae = model2.mspace_ae
|
|
441
|
+
|
|
442
|
+
if return_trainer:
|
|
443
|
+
result = model, trainer
|
|
444
|
+
else:
|
|
445
|
+
result = model
|
|
446
|
+
|
|
447
|
+
return result
|
|
583
448
|
|
|
584
449
|
def mspace_viz_transform(X, return_model=False, **kwargs):
|
|
585
450
|
X = X.float().cpu()
|
|
586
|
-
# model, trainer = try_train_mspace(X, X, return_trainer=True, **kwargs)
|
|
587
451
|
model, trainer = train_mspace_model(X, X, return_trainer=True, **kwargs)
|
|
588
452
|
|
|
589
453
|
batch_size = kwargs.get('batch_size', 1000)
|
|
@@ -593,3 +457,4 @@ def mspace_viz_transform(X, return_model=False, **kwargs):
|
|
|
593
457
|
if return_model:
|
|
594
458
|
return compressed, model.mspace_ae
|
|
595
459
|
return compressed
|
|
460
|
+
|
|
@@ -203,7 +203,7 @@ def nystrom_propagate(
|
|
|
203
203
|
Returns:
|
|
204
204
|
torch.Tensor: output propagated by nearest neighbors, shape (N, D)
|
|
205
205
|
"""
|
|
206
|
-
if X.shape[0] <= SMALL_SCALE_THRESHOLD and
|
|
206
|
+
if X.shape[0] <= SMALL_SCALE_THRESHOLD and nystrom_X.shape == X.shape and torch.allclose(nystrom_X.to(X.device), X, atol=1e-6):
|
|
207
207
|
# skip propagation if nystrom_out is the same as X, for small scale graph that don't need nystrom approximation
|
|
208
208
|
if return_indices:
|
|
209
209
|
return nystrom_out, np.arange(X.shape[0])
|
|
@@ -72,11 +72,11 @@ class NcutPredictor:
|
|
|
72
72
|
hierarchy_assign.append(self.get_n_segments(n_eig))
|
|
73
73
|
self._hierarchy_assign = hierarchy_assign
|
|
74
74
|
|
|
75
|
-
def get_n_segments(self, n_cluster: int
|
|
75
|
+
def get_n_segments(self, n_cluster: int) -> torch.Tensor:
|
|
76
76
|
self.__check_initialized()
|
|
77
|
-
eigvecs = self.get_n_eigvecs(
|
|
77
|
+
eigvecs = self.get_n_eigvecs(n_cluster)
|
|
78
78
|
# kway_eigvec = kway_ncut(eigvecs, device=self.device, sample_idx=self._kway_sample_idx)
|
|
79
|
-
kway_eigvec = quick_kway(eigvecs, n_eig=
|
|
79
|
+
kway_eigvec = quick_kway(eigvecs, n_eig=n_cluster, n_clusters=n_cluster, device=self.device)
|
|
80
80
|
cluster_assignment = kway_eigvec.argmax(dim=1).cpu()
|
|
81
81
|
return cluster_assignment
|
|
82
82
|
|
{ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/vision_predictor.py
RENAMED
|
@@ -76,7 +76,7 @@ class NcutVisionPredictor:
|
|
|
76
76
|
torch.Tensor: Cluster assignment for the images. (b, h, w)
|
|
77
77
|
"""
|
|
78
78
|
self.__check_initialized()
|
|
79
|
-
cluster_assignment = self.predictor.get_n_segments(n_segment
|
|
79
|
+
cluster_assignment = self.predictor.get_n_segments(n_segment)
|
|
80
80
|
b, h, w = len(self._images), self._feat_hws[0], self._feat_hws[1]
|
|
81
81
|
cluster_assignment = cluster_assignment.reshape(b, h, w)
|
|
82
82
|
return cluster_assignment
|
|
@@ -114,7 +114,7 @@ def rbf_eigvec_manual_grad(
|
|
|
114
114
|
return grad_u
|
|
115
115
|
|
|
116
116
|
|
|
117
|
-
class
|
|
117
|
+
class EigvecOuterProduct(torch.autograd.Function):
|
|
118
118
|
"""
|
|
119
119
|
A (symmetric) -> {P_b}_b, where P_b = U_{S_b} U_{S_b}^T and S_b is specified by a boolean mask.
|
|
120
120
|
|
|
@@ -212,22 +212,24 @@ class MultiSpectralProjectorFromMasks(torch.autograd.Function):
|
|
|
212
212
|
grad_A_used = grad_A_used + (Bmat + Bmat.T)
|
|
213
213
|
|
|
214
214
|
grad_A = 0.5 * (grad_A_used + grad_A_used.T) if symmetrize else grad_A_used
|
|
215
|
+
# print(f"grad_A.shape: {grad_A.shape}, grad_A.norm: {grad_A.norm()}")
|
|
215
216
|
return grad_A, None, None, None
|
|
216
217
|
|
|
217
218
|
|
|
218
|
-
def
|
|
219
|
+
def eigvec_outer_product(
|
|
219
220
|
A: torch.Tensor,
|
|
220
|
-
|
|
221
|
+
eigval_masks: torch.Tensor,
|
|
221
222
|
gap_eps: float = 0.0,
|
|
222
223
|
symmetrize: bool = True,
|
|
223
224
|
):
|
|
224
225
|
"""
|
|
225
|
-
|
|
226
|
+
Computes the outer product of the eigenvectors U U^T, where U is the eigenvector matrix of A.
|
|
227
|
+
gradient of this function is stable, even if the eigenvalues are close to each other.
|
|
226
228
|
|
|
227
229
|
masks: [B,N] bool in DESCENDING eigen-order (0 = largest eigenvalue).
|
|
228
230
|
returns P: [B,N,N]
|
|
229
231
|
"""
|
|
230
|
-
return
|
|
232
|
+
return EigvecOuterProduct.apply(A, eigval_masks, gap_eps, symmetrize)
|
|
231
233
|
|
|
232
234
|
|
|
233
235
|
if __name__ == "__main__":
|
|
@@ -241,12 +243,12 @@ if __name__ == "__main__":
|
|
|
241
243
|
A1 = torch.randn(N, N)
|
|
242
244
|
A1 = 0.5 * (A1 + A1.T)
|
|
243
245
|
A1.requires_grad_(True)
|
|
244
|
-
P1 =
|
|
246
|
+
P1 = eigvec_outer_product(A1, masks)
|
|
245
247
|
|
|
246
248
|
A2 = torch.randn(N, N)
|
|
247
249
|
A2 = 0.5 * (A2 + A2.T)
|
|
248
250
|
A2.requires_grad_(True)
|
|
249
|
-
P2 =
|
|
251
|
+
P2 = eigvec_outer_product(A2, masks)
|
|
250
252
|
|
|
251
253
|
loss = torch.norm(P1 - P2, p=2, dim=(0, 1)).sum()
|
|
252
254
|
loss.backward()
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/dino/hires_dino.py
RENAMED
|
File without changes
|
{ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/dino/lowres_dino.py
RENAMED
|
File without changes
|
|
File without changes
|
{ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/dino/transform.py
RENAMED
|
File without changes
|
{ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/dino_predictor.py
RENAMED
|
File without changes
|
{ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/jafar_predictor.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|