ncut-pytorch 3.0.0.dev6__tar.gz → 3.0.0.dev8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/PKG-INFO +1 -1
  2. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/color/coloring.py +8 -8
  3. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/color/mspace.py +132 -267
  4. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/ncuts/ncut_nystrom.py +1 -1
  5. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/predictor.py +3 -3
  6. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/vision_predictor.py +1 -1
  7. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/utils/grad.py +9 -7
  8. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch.egg-info/PKG-INFO +1 -1
  9. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/pyproject.toml +1 -1
  10. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/LICENSE +0 -0
  11. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/README.md +0 -0
  12. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/__init__.py +0 -0
  13. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/color/__init__.py +0 -0
  14. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/color/mspace_nopl.py +0 -0
  15. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/ncut.py +0 -0
  16. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/ncuts/__init__.py +0 -0
  17. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/ncuts/ncut_click.py +0 -0
  18. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/ncuts/ncut_kway.py +0 -0
  19. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/__init__.py +0 -0
  20. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/dino/__init__.py +0 -0
  21. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/dino/api.py +0 -0
  22. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/dino/dinov3.py +0 -0
  23. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/dino/hires_dino.py +0 -0
  24. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/dino/lowres_dino.py +0 -0
  25. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/dino/patch.py +0 -0
  26. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/dino/transform.py +0 -0
  27. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/dino_predictor.py +0 -0
  28. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/predictor/jafar_predictor.py +0 -0
  29. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/utils/__init__.py +0 -0
  30. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/utils/device.py +0 -0
  31. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/utils/math.py +0 -0
  32. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/utils/sample.py +0 -0
  33. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/utils/sigma.py +0 -0
  34. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch/utils/torch_mod.py +0 -0
  35. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch.egg-info/SOURCES.txt +0 -0
  36. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch.egg-info/dependency_links.txt +0 -0
  37. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch.egg-info/requires.txt +0 -0
  38. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/ncut_pytorch.egg-info/top_level.txt +0 -0
  39. {ncut_pytorch-3.0.0.dev6 → ncut_pytorch-3.0.0.dev8}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ncut_pytorch
3
- Version: 3.0.0.dev6
3
+ Version: 3.0.0.dev8
4
4
  Summary: Normalized Cut and Spectral Embedding
5
5
  Author-email: Huzheng Yang <huze.yann@gmail.com>
6
6
  License-Expression: MIT
@@ -1,7 +1,7 @@
1
1
  __all__ = ["mspace_color", "tsne_color", "umap_color", "umap_sphere_color", "rotate_rgb_cube", "convert_to_lab_color"]
2
2
 
3
3
  import warnings
4
- from typing import Any, Callable, Dict, Literal, Tuple, Optional
4
+ from typing import Any, Callable, Dict, Literal, Tuple, Optional, List
5
5
 
6
6
  from numba.core.types import none
7
7
  import numpy as np
@@ -19,7 +19,7 @@ def _identity(X: torch.Tensor) -> torch.Tensor:
19
19
  def mspace_color(
20
20
  X: torch.Tensor,
21
21
  q: float = 0.95,
22
- n_eig: Optional[int] = 8,
22
+ n_eig_list: Optional[List[int]] = [4, 16, 64],
23
23
  n_dim: int = 3,
24
24
  training_steps: int = 1000,
25
25
  progress_bar: bool = False,
@@ -34,15 +34,15 @@ def mspace_color(
34
34
 
35
35
  low_dim_embedding = mspace_viz_transform(
36
36
  X=X,
37
- n_eig=n_eig,
38
- mood_dim=n_dim,
37
+ n_eig_list=n_eig_list,
38
+ z_dim=n_dim,
39
39
  training_steps=training_steps,
40
- decoder_training_steps=0,
41
40
  progress_bar=progress_bar,
42
- eigvec_loss=1.0,
43
- recon_loss=1.0,
41
+ flag_loss_mode='z',
42
+ flag_loss=1.0,
43
+ recon_loss=.001,
44
44
  zero_center_loss=0.001,
45
- repulsion_loss=0.01,
45
+ repulsion_loss=0.001,
46
46
  **kwargs)
47
47
 
48
48
  rgb = rgb_from_nd_colormap(low_dim_embedding, q=q)
@@ -2,9 +2,10 @@ __all__ = ["train_mspace_model", "mspace_viz_transform"]
2
2
 
3
3
  import logging
4
4
  from collections import defaultdict
5
- from functools import partial
5
+ from functools import partial, wraps
6
6
  import warnings
7
7
  import pytorch_lightning as pl
8
+ from typing import List
8
9
 
9
10
  import numpy as np
10
11
  import torch
@@ -15,58 +16,21 @@ from tqdm import tqdm
15
16
 
16
17
  from ncut_pytorch.utils.device import auto_device
17
18
  from ncut_pytorch import ncut_fn
18
- from ncut_pytorch.utils.math import grad_safe_eig_solve, normalize_affinity, rbf_affinity, svd_lowrank, cosine_affinity
19
-
20
-
21
- def _flag_ncut_loss(eigvec_gt, eigvec_hat, n_eig, weight):
22
- _eigvec_gt = eigvec_gt[:, :n_eig]
23
- _eigvec_hat = eigvec_hat[:, :n_eig]
24
- left = _eigvec_gt @ _eigvec_gt.T
25
- right = _eigvec_hat @ _eigvec_hat.T
26
- left = left * weight[:, None] * weight[None, :]
27
- right = right * weight[:, None] * weight[None, :]
28
- loss = F.l1_loss(left, right)
29
- return loss
30
-
31
- def flag_space_loss(eigvec_gt, eigvec_hat, n_eig, start=2, step_mult=2, weight=None):
32
- if torch.all(eigvec_gt == 0) or torch.all(eigvec_hat == 0):
33
- return torch.tensor(0, device=eigvec_gt.device)
34
-
35
- if weight is None:
36
- weight = torch.ones(eigvec_gt.shape[0], device=eigvec_gt.device, dtype=eigvec_gt.dtype)
37
-
38
- loss = 0
39
- n_eig = start // step_mult
40
- while True:
41
- n_eig *= step_mult
42
- loss += _flag_ncut_loss(eigvec_gt, eigvec_hat, n_eig, weight)
43
- if n_eig > eigvec_gt.shape[1] or n_eig > eigvec_hat.shape[1]:
44
- break
45
- return loss
46
-
47
- def filter_closeby_eigval(eigvec, eigval, threshold=1e-3):
48
- # filter out eigvals that are too close to each other
49
- # so the gradient is more stable
50
- eigval_diff = torch.diff(eigval).abs()
51
- keep_idx = torch.where(eigval_diff > threshold)[0]
52
- return eigvec[:, keep_idx], eigval[keep_idx]
53
-
54
- def ncut_wrapper(features, n_eig, sigma=None):
55
- # TODO: gradient is not stable.
56
-
57
- # features.requires_grad_(True)
58
- sigma = sigma or features.std(0).sum().item()
59
- eigvec, eigval = ncut_fn(features, n_eig, sigma=sigma)
19
+ from ncut_pytorch.utils.math import normalize_affinity, rbf_affinity, cosine_affinity
20
+ from ncut_pytorch.utils.grad import eigvec_outer_product
21
+ from ncut_pytorch.utils.sigma import find_sigma_by_degree
22
+
23
+
24
+ def get_eigvec_outer_product(features: torch.Tensor, n_eig_list: List[int]) -> torch.Tensor:
25
+ sigma = find_sigma_by_degree(features)
60
26
  W = rbf_affinity(features, sigma=sigma)
61
- # W = cosine_affinity(features, sigma=1.0)
62
27
  A = normalize_affinity(W)
63
- eigvec, eigval, _ = svd_lowrank(A, n_eig)
64
- # eigval, eigvec = torch.linalg.eigh(A)
65
- # eigvec = eigvec.flip(dims=[1])
66
- # eigval = eigval.flip(dims=[0])
67
- # eigvec = eigvec[:, :n_eig]
68
- # eigval = eigval[:n_eig]
69
- return eigvec, eigval
28
+ eigval_masks = torch.zeros(len(n_eig_list), A.shape[0], dtype=torch.bool)
29
+ for i, n_eig in enumerate(n_eig_list):
30
+ eigval_masks[i, :n_eig] = True
31
+ P = eigvec_outer_product(A, eigval_masks)
32
+ return P
33
+
70
34
 
71
35
 
72
36
  class MovingMinMax(nn.Module):
@@ -112,12 +76,12 @@ class MLP(nn.Module):
112
76
 
113
77
 
114
78
  class MspaceAutoEncoder(nn.Module):
115
- def __init__(self, in_dim, out_dim, mood_dim, n_layer=4, latent_dim=256, encoder_activation='gelu', decoder_activation='gelu', final_activation='identity'):
79
+ def __init__(self, in_dim, out_dim, z_dim, n_layer=4, latent_dim=256, encoder_activation='gelu', decoder_activation='gelu', final_activation='identity'):
116
80
  super().__init__()
117
- self.encoder = MLP(in_dim, mood_dim, n_layer, latent_dim, encoder_activation, final_activation)
81
+ self.encoder = MLP(in_dim, z_dim, n_layer, latent_dim, encoder_activation, final_activation)
118
82
  self.decoder = nn.Sequential(
119
- MovingMinMax(mood_dim),
120
- MLP(mood_dim, out_dim, n_layer, latent_dim, decoder_activation)
83
+ MovingMinMax(z_dim),
84
+ MLP(z_dim, out_dim, n_layer, latent_dim, decoder_activation)
121
85
  )
122
86
 
123
87
 
@@ -260,34 +224,29 @@ class TrainDecoder(pl.LightningModule):
260
224
  class TrainEncoder(pl.LightningModule):
261
225
  N_ITER_PER_STEP = 10
262
226
 
263
- def __init__(self, in_dim, out_dim, mood_dim=2, n_eig=8, n_elbow=3,
227
+ def __init__(self, in_dim, out_dim, z_dim=2, n_eig_list=[4, 16, 64],
264
228
  n_layer=4, latent_dim=256,
265
- eigvec_loss=100, recon_loss=0,
266
- riemann_curvature_loss=0., axis_align_loss=0,
267
- repulsion_loss=0.1, attraction_loss=0.,
268
- boundary_loss=0., zero_center_loss=0.01,
269
- lr=0.001, progress_bar=True, training_steps=500,
229
+ flag_loss_mode='z_eigvec', # 'z_eigvec' or 'z'
230
+ flag_loss=1.0, recon_loss=1e-3,
231
+ repulsion_loss=1e-3, zero_center_loss=1e-3,
232
+ lr=1e-3, progress_bar=True, training_steps=1000,
270
233
  encoder_activation='gelu', decoder_activation='gelu',
271
234
  final_activation='identity',
272
- log_grad_norm=False,
235
+ log_grad_norm=True,
273
236
  **kwargs):
274
237
  super().__init__()
275
238
 
276
- self.mspace_ae = MspaceAutoEncoder(in_dim, out_dim, mood_dim, n_layer, latent_dim, encoder_activation, decoder_activation, final_activation)
239
+ self.mspace_ae = MspaceAutoEncoder(in_dim, out_dim, z_dim, n_layer, latent_dim, encoder_activation, decoder_activation, final_activation)
277
240
 
278
241
  self.loss_history = defaultdict(list)
279
242
 
280
- self.eigvec_loss = eigvec_loss
243
+ self.flag_loss_mode = flag_loss_mode
244
+ self.flag_loss = flag_loss
281
245
  self.recon_loss = recon_loss
282
- self.riemann_curvature_loss = riemann_curvature_loss
283
- self.axis_align_loss = axis_align_loss
284
246
  self.repulsion_loss = repulsion_loss
285
- self.attraction_loss = attraction_loss
286
- self.boundary_loss = boundary_loss
287
247
  self.zero_center_loss = zero_center_loss
288
248
  self.lr = lr
289
- self.n_eig = n_eig
290
- self.n_elbow = n_elbow
249
+ self.n_eig_list = n_eig_list
291
250
 
292
251
  self.progress_bar = progress_bar
293
252
  self.training_steps = training_steps
@@ -303,11 +262,11 @@ class TrainEncoder(pl.LightningModule):
303
262
  return self.mspace_ae.encoder(x)
304
263
 
305
264
 
306
- def _log_loss(self, loss, name, log_grad_norm=False):
307
- if self.logger is None:
308
- return
309
- self.logger.log_metrics({f'loss/{name}': loss.item()}, step=self.global_step)
310
- self.loss_history[f"loss/{name}"].append(loss.item())
265
+ def _log_loss(self, loss, name, log_grad_norm=False, iteration=0):
266
+ # if self.logger is None:
267
+ # return
268
+ # self.logger.log_metrics({f'loss/{name}': loss.item()}, step=self.global_step)
269
+ # self.loss_history[f"loss/{name}"].append(loss.item())
311
270
  if log_grad_norm:
312
271
  grad_norm = 0
313
272
  self.manual_backward(loss, retain_graph=True)
@@ -315,8 +274,10 @@ class TrainEncoder(pl.LightningModule):
315
274
  if param.grad is not None:
316
275
  grad_norm += param.grad.norm().item() ** 2
317
276
  grad_norm = grad_norm ** 0.5
318
- self.logger.log_metrics({f'grad/{name}': grad_norm}, step=self.global_step)
319
- self.loss_history[f"grad/{name}"].append(grad_norm)
277
+ # self.logger.log_metrics({f'grad/{name}': grad_norm}, step=self.global_step)
278
+ # self.loss_history[f"grad/{name}"].append(grad_norm)
279
+ if iteration % 10 == 0:
280
+ print(f"loss/{name} grad_norm: {grad_norm:.2e}, iteration: {iteration}")
320
281
  self.optimizers().zero_grad()
321
282
 
322
283
  def training_step(self, batch):
@@ -328,12 +289,10 @@ class TrainEncoder(pl.LightningModule):
328
289
 
329
290
  input_feats, output_feats = batch
330
291
 
331
- stored_eigvec_gt = {}
332
- # with torch.no_grad():
292
+ P_gt = {}
333
293
  # Compute eigvec_gt only once for each set of iterations
334
- if self.eigvec_loss != 0:
335
- eigvec_gt, eigval_gt = ncut_wrapper(input_feats, self.n_eig, sigma=input_feats.std(0).sum().item())
336
- stored_eigvec_gt = eigvec_gt
294
+ if self.flag_loss != 0:
295
+ P_gt = get_eigvec_outer_product(input_feats, self.n_eig_list)
337
296
 
338
297
 
339
298
  # Run the same batch 10 times, updating parameters after each iteration
@@ -344,28 +303,36 @@ class TrainEncoder(pl.LightningModule):
344
303
  log_grad_norm = iteration == 0 and self.log_grad_norm
345
304
 
346
305
  total_loss = 0
347
- if self.eigvec_loss != 0:
348
- assert feats_compressed.requires_grad, "feats_compressed must be a tensor with requires_grad=True"
349
- eigvec_hat, eigval_hat = ncut_wrapper(feats_compressed, self.n_eig, sigma=feats_compressed.std(0).sum().item())
350
- assert eigvec_hat.requires_grad, "eigvec_hat must be a tensor with requires_grad=True"
351
- eigvec_loss = flag_space_loss(stored_eigvec_gt, eigvec_hat, n_eig=self.n_eig, weight=None)
352
- eigvec_loss = eigvec_loss * self.eigvec_loss
353
- total_loss += eigvec_loss
354
- self._log_loss(eigvec_loss, "eigvec", log_grad_norm=log_grad_norm)
306
+ if self.flag_loss != 0:
307
+ if self.flag_loss_mode == 'z_eigvec':
308
+ P_hat = get_eigvec_outer_product(feats_compressed, self.n_eig_list)
309
+ flag_loss = F.mse_loss(P_gt, P_hat)
310
+ elif self.flag_loss_mode == 'z':
311
+ P_hat = feats_compressed @ feats_compressed.T
312
+ flag_loss = F.mse_loss(P_gt.mean(0), P_hat)
313
+ flag_loss = flag_loss * self.flag_loss
314
+ total_loss += flag_loss
315
+ self._log_loss(flag_loss, "flag_loss", log_grad_norm=log_grad_norm, iteration=self.trainer.global_step)
355
316
 
356
317
  if self.recon_loss > 0:
357
318
  feats_uncompressed = self.mspace_ae.decoder(feats_compressed)
358
- recon_loss = F.smooth_l1_loss(output_feats, feats_uncompressed, beta=0.1)
319
+ recon_loss = F.mse_loss(output_feats, feats_uncompressed)
359
320
  recon_loss = recon_loss * self.recon_loss
360
321
  total_loss += recon_loss
361
- self._log_loss(recon_loss, "recon", log_grad_norm=log_grad_norm)
322
+ self._log_loss(recon_loss, "recon_loss", log_grad_norm=log_grad_norm, iteration=self.trainer.global_step)
362
323
 
363
324
  if self.zero_center_loss > 0:
364
325
  # zero_center_loss = feats_compressed.abs().mean()
365
326
  zero_center_loss = (feats_compressed ** 2).mean()
366
327
  zero_center_loss = zero_center_loss * self.zero_center_loss
367
328
  total_loss += zero_center_loss
368
- self._log_loss(zero_center_loss, "zero_center", log_grad_norm=log_grad_norm)
329
+ self._log_loss(zero_center_loss, "zero_center_loss", log_grad_norm=log_grad_norm, iteration=self.trainer.global_step)
330
+
331
+ if self.repulsion_loss > 0:
332
+ repulsion_loss = compute_repulsion_loss(feats_compressed)
333
+ repulsion_loss = repulsion_loss * self.repulsion_loss
334
+ total_loss += repulsion_loss
335
+ self._log_loss(repulsion_loss, "repulsion_loss", log_grad_norm=log_grad_norm, iteration=self.trainer.global_step)
369
336
 
370
337
  # Log the loss for this iteration
371
338
  if self.logger is not None:
@@ -387,86 +354,6 @@ class TrainEncoder(pl.LightningModule):
387
354
  return optimizer
388
355
 
389
356
 
390
- # Moving Average Callback
391
- class BestModelsAvgCallback(pl.Callback):
392
- """
393
- Callback to store the top models with the lowest total loss and average them at the end.
394
- """
395
- def __init__(self, top_k=10):
396
- super().__init__()
397
- self.top_k = top_k
398
- self.best_models = [] # List of tuples (loss, params)
399
- self.best_loss = float('inf')
400
-
401
- def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
402
- """Store the model if it has a lower loss than the current best models."""
403
-
404
- # skip if the global step is on the frist 80% of the training steps
405
- if trainer.global_step < trainer.max_steps * 0.8:
406
- return
407
-
408
- # Get current model parameters
409
- current_params = {}
410
- for name, param in pl_module.named_parameters():
411
- current_params[name] = param.data.clone()
412
-
413
- # Get the current total loss from the outputs
414
- # The outputs from training_step is the total loss
415
- current_loss = outputs['loss'].item()
416
-
417
- # If we have fewer than top_k models or this is a better model
418
- if len(self.best_models) < self.top_k or current_loss < self.best_loss:
419
- # Add to best models
420
- self.best_models.append((current_loss, current_params))
421
-
422
- # Sort by loss (ascending)
423
- self.best_models.sort(key=lambda x: x[0])
424
-
425
- # Keep only top_k models
426
- if len(self.best_models) > self.top_k:
427
- self.best_models = self.best_models[:self.top_k]
428
-
429
- # Update best loss
430
- self.best_loss = self.best_models[0][0]
431
-
432
- def on_train_end(self, trainer, pl_module):
433
- """Average the best model parameters and apply them to the model at the end of training."""
434
- if self.best_models:
435
- # Initialize averaged parameters
436
- avg_params = {}
437
-
438
- # Get parameter names from the first model
439
- param_names = self.best_models[0][1].keys()
440
-
441
- # Initialize with zeros of the same shape
442
- for name in param_names:
443
- avg_params[name] = torch.zeros_like(self.best_models[0][1][name])
444
-
445
- # Sum all parameters from best models
446
- for _, params in self.best_models:
447
- for name in param_names:
448
- avg_params[name] += params[name]
449
-
450
- # Divide by the number of models to get the average
451
- for name in param_names:
452
- avg_params[name] /= len(self.best_models)
453
-
454
- # Apply averaged parameters to the model
455
- for name, param in pl_module.named_parameters():
456
- if name in avg_params:
457
- param.data.copy_(avg_params[name])
458
-
459
- # Log the best loss and average loss
460
- if pl_module.logger is not None:
461
- pl_module.logger.log_metrics({
462
- 'best_loss': self.best_loss,
463
- 'avg_loss': sum(loss for loss, _ in self.best_models) / len(self.best_models)
464
- }, step=trainer.global_step)
465
-
466
- self.best_models = []
467
- self.best_loss = float('inf')
468
-
469
-
470
357
  def compute_repulsion_loss(
471
358
  points: torch.Tensor, # [N, D]
472
359
  ) -> torch.Tensor:
@@ -475,115 +362,92 @@ def compute_repulsion_loss(
475
362
  mask = torch.eye(points.shape[0], device=points.device).bool()
476
363
  dist_matrix = dist_matrix + mask * 1e10
477
364
  nearest_dists, _ = torch.min(dist_matrix, dim=1)
478
- repulsion = 1.0 / (nearest_dists + 0.01)
365
+ repulsion = 1.0 / (nearest_dists + 1)
479
366
  return torch.mean(repulsion)
480
367
 
481
368
 
482
- def train_mspace_model(compress_feats, uncompress_feats, training_steps=500, decoder_training_steps=1000,
483
- batch_size=1000, return_trainer=False, progress_bar=True,
484
- logger=False, use_wandb=False, model_avg_window=3, **model_kwargs):
485
- # Disable lightning logs only for this function
486
- original_lightning_level = logging.getLogger('lightning').level
487
- original_pytorch_lightning_level = logging.getLogger("pytorch_lightning").level
488
-
489
- logging.getLogger('lightning').setLevel(0)
490
- logging.getLogger("pytorch_lightning").setLevel(0)
491
-
369
+ def suppress_lightning_logs(func):
370
+ """Temporarily suppresses noisy PyTorch Lightning startup logs."""
492
371
  class IgnorePLFilter(logging.Filter):
493
372
  def filter(self, record):
494
373
  keywords = ['available:', 'CUDA', 'LOCAL_RANK:']
495
374
  return not any(keyword in record.getMessage() for keyword in keywords)
496
-
497
- pl_filter = IgnorePLFilter()
498
- logger_rank_zero = logging.getLogger('pytorch_lightning.utilities.rank_zero')
499
- logger_cuda = logging.getLogger('pytorch_lightning.accelerators.cuda')
500
- logger_rank_zero.addFilter(pl_filter)
501
- logger_cuda.addFilter(pl_filter)
502
-
503
- try:
504
- compress_feats = compress_feats.float().cpu()
505
- uncompress_feats = uncompress_feats.float().cpu()
506
- l, c_in = compress_feats.shape
507
- c_out = uncompress_feats.shape[1]
508
375
 
509
- model = TrainEncoder(c_in, c_out, training_steps=training_steps, progress_bar=progress_bar, **model_kwargs)
510
-
511
- dataset = TensorDataset(compress_feats, uncompress_feats)
512
- dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
513
-
514
- trainer_args = {
515
- 'accelerator': auto_device(),
516
- 'devices': 1,
517
- 'enable_checkpointing': False,
518
- 'enable_progress_bar': False,
519
- 'enable_model_summary': False,
520
- 'callbacks': [BestModelsAvgCallback(top_k=model_avg_window)],
521
- }
522
-
523
- if use_wandb and not logger:
524
- logger = pl.loggers.WandbLogger(project='mspace', name='mspace')
525
-
526
- # train the autoencoder jointly
527
- trainer = pl.Trainer(max_steps=training_steps, logger=logger, **trainer_args)
528
- trainer.fit(model, dataloader)
529
-
530
- mspace_ae = model.mspace_ae
531
-
532
- if decoder_training_steps > 0:
533
- # train the decoder only
534
- trainer = pl.Trainer(max_steps=decoder_training_steps, logger=logger, **trainer_args)
535
- model2 = TrainDecoder(mspace_ae, progress_bar=progress_bar, training_steps=decoder_training_steps, **model_kwargs)
536
- trainer.fit(model2, dataloader)
537
- model.mspace_ae = model2.mspace_ae
538
-
539
- if return_trainer:
540
- result = model, trainer
541
- else:
542
- result = model
543
- finally:
544
- # Restore original logging configuration
545
- logging.getLogger('lightning').setLevel(original_lightning_level)
546
- logging.getLogger("pytorch_lightning").setLevel(original_pytorch_lightning_level)
547
- logger_rank_zero.removeFilter(pl_filter)
548
- logger_cuda.removeFilter(pl_filter)
376
+ @wraps(func)
377
+ def wrapper(*args, **kwargs):
378
+ original_lightning_level = logging.getLogger('lightning').level
379
+ original_pytorch_lightning_level = logging.getLogger("pytorch_lightning").level
380
+ logging.getLogger('lightning').setLevel(0)
381
+ logging.getLogger("pytorch_lightning").setLevel(0)
382
+ pl_filter = IgnorePLFilter()
383
+ logger_rank_zero = logging.getLogger('pytorch_lightning.utilities.rank_zero')
384
+ logger_cuda = logging.getLogger('pytorch_lightning.accelerators.cuda')
385
+ logger_rank_zero.addFilter(pl_filter)
386
+ logger_cuda.addFilter(pl_filter)
387
+ try:
388
+ return func(*args, **kwargs)
389
+ finally:
390
+ logging.getLogger('lightning').setLevel(original_lightning_level)
391
+ logging.getLogger("pytorch_lightning").setLevel(original_pytorch_lightning_level)
392
+ logger_rank_zero.removeFilter(pl_filter)
393
+ logger_cuda.removeFilter(pl_filter)
394
+
395
+ return wrapper
396
+
397
+
398
+ @suppress_lightning_logs
399
+ def train_mspace_model(compress_feats, uncompress_feats, training_steps=1000, decoder_training_steps=1000,
400
+ batch_size=1000, return_trainer=False, progress_bar=False,
401
+ logger=False, use_wandb=False, **model_kwargs):
402
+ # check args
403
+ valid_keys = ['z_dim', 'flag_loss', 'flag_loss_mode', 'recon_loss', 'repulsion_loss', 'zero_center_loss', 'n_eig_list', 'n_layer', 'latent_dim', 'lr']
404
+ for key in model_kwargs.keys():
405
+ if key not in valid_keys:
406
+ raise ValueError(f"Invalid argument key: {key}. Valid keys: {valid_keys}")
407
+
408
+ compress_feats = compress_feats.float().cpu()
409
+ uncompress_feats = uncompress_feats.float().cpu()
410
+ l, c_in = compress_feats.shape
411
+ c_out = uncompress_feats.shape[1]
412
+
413
+ model = TrainEncoder(c_in, c_out, training_steps=training_steps, progress_bar=progress_bar, **model_kwargs)
549
414
 
550
- return result
415
+ dataset = TensorDataset(compress_feats, uncompress_feats)
416
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
417
+
418
+ trainer_args = {
419
+ 'accelerator': auto_device(),
420
+ 'devices': 1,
421
+ 'enable_checkpointing': False,
422
+ 'enable_progress_bar': False,
423
+ 'enable_model_summary': False,
424
+ }
551
425
 
426
+ if use_wandb and not logger:
427
+ logger = pl.loggers.WandbLogger(project='mspace', name='mspace')
552
428
 
553
- def try_train_mspace(*args, **kwargs):
554
- # TODO: msapce training sometimes fails into nan, why?
555
- max_retries = 10
556
- retry_count = 0
557
- original_n_eig = kwargs.get('n_eig', 8)
429
+ # train the autoencoder jointly
430
+ trainer = pl.Trainer(max_steps=training_steps, logger=logger, **trainer_args)
431
+ trainer.fit(model, dataloader)
558
432
 
559
- while retry_count < max_retries:
560
- try:
561
- model, trainer = train_mspace_model(*args, **kwargs)
562
- return model, trainer
563
- except Exception as e:
564
- retry_count += 1
565
- current_n_eig = kwargs.get('n_eig', original_n_eig)
566
- n_eig = int(current_n_eig // 2)
567
-
568
- # If n_eig becomes too small, disable eigvec_loss and use minimum value
569
- if n_eig < 2:
570
- kwargs['eigvec_loss'] = 0.0
571
- kwargs['recon_loss'] = 1.0
572
- kwargs['n_eig'] = 2 # Ensure n_eig is at least 2
573
- warnings.warn(f"Error in training mspace model: {e}. Disabling eigvec_loss and using n_eig=1. Retrying ({retry_count}/{max_retries})...")
574
- else:
575
- kwargs['n_eig'] = n_eig
576
- warnings.warn(f"Error in training mspace model: {e}. Trying with n_eig={n_eig}... Retrying ({retry_count}/{max_retries})...")
577
-
578
- # If we've exhausted all retries, raise the exception
579
- if retry_count >= max_retries:
580
- raise Exception(f"Failed to train mspace model after {max_retries} retries. Last error: {e}")
581
-
582
- torch.cuda.empty_cache()
433
+ mspace_ae = model.mspace_ae
434
+
435
+ if decoder_training_steps > 0:
436
+ # train the decoder only
437
+ trainer = pl.Trainer(max_steps=decoder_training_steps, logger=logger, **trainer_args)
438
+ model2 = TrainDecoder(mspace_ae, progress_bar=progress_bar, training_steps=decoder_training_steps, **model_kwargs)
439
+ trainer.fit(model2, dataloader)
440
+ model.mspace_ae = model2.mspace_ae
441
+
442
+ if return_trainer:
443
+ result = model, trainer
444
+ else:
445
+ result = model
446
+
447
+ return result
583
448
 
584
449
  def mspace_viz_transform(X, return_model=False, **kwargs):
585
450
  X = X.float().cpu()
586
- # model, trainer = try_train_mspace(X, X, return_trainer=True, **kwargs)
587
451
  model, trainer = train_mspace_model(X, X, return_trainer=True, **kwargs)
588
452
 
589
453
  batch_size = kwargs.get('batch_size', 1000)
@@ -593,3 +457,4 @@ def mspace_viz_transform(X, return_model=False, **kwargs):
593
457
  if return_model:
594
458
  return compressed, model.mspace_ae
595
459
  return compressed
460
+
@@ -203,7 +203,7 @@ def nystrom_propagate(
203
203
  Returns:
204
204
  torch.Tensor: output propagated by nearest neighbors, shape (N, D)
205
205
  """
206
- if X.shape[0] <= SMALL_SCALE_THRESHOLD and nystrom_out.shape == X.shape and torch.allclose(nystrom_X.to(X.device), X, atol=1e-6):
206
+ if X.shape[0] <= SMALL_SCALE_THRESHOLD and nystrom_X.shape == X.shape and torch.allclose(nystrom_X.to(X.device), X, atol=1e-6):
207
207
  # skip propagation if nystrom_out is the same as X, for small scale graph that don't need nystrom approximation
208
208
  if return_indices:
209
209
  return nystrom_out, np.arange(X.shape[0])
@@ -72,11 +72,11 @@ class NcutPredictor:
72
72
  hierarchy_assign.append(self.get_n_segments(n_eig))
73
73
  self._hierarchy_assign = hierarchy_assign
74
74
 
75
- def get_n_segments(self, n_cluster: int, n_eig: int = 10) -> torch.Tensor:
75
+ def get_n_segments(self, n_cluster: int) -> torch.Tensor:
76
76
  self.__check_initialized()
77
- eigvecs = self.get_n_eigvecs(n_eig)
77
+ eigvecs = self.get_n_eigvecs(n_cluster)
78
78
  # kway_eigvec = kway_ncut(eigvecs, device=self.device, sample_idx=self._kway_sample_idx)
79
- kway_eigvec = quick_kway(eigvecs, n_eig=n_eig, n_clusters=n_cluster, device=self.device)
79
+ kway_eigvec = quick_kway(eigvecs, n_eig=n_cluster, n_clusters=n_cluster, device=self.device)
80
80
  cluster_assignment = kway_eigvec.argmax(dim=1).cpu()
81
81
  return cluster_assignment
82
82
 
@@ -76,7 +76,7 @@ class NcutVisionPredictor:
76
76
  torch.Tensor: Cluster assignment for the images. (b, h, w)
77
77
  """
78
78
  self.__check_initialized()
79
- cluster_assignment = self.predictor.get_n_segments(n_segment, n_eig)
79
+ cluster_assignment = self.predictor.get_n_segments(n_segment)
80
80
  b, h, w = len(self._images), self._feat_hws[0], self._feat_hws[1]
81
81
  cluster_assignment = cluster_assignment.reshape(b, h, w)
82
82
  return cluster_assignment
@@ -114,7 +114,7 @@ def rbf_eigvec_manual_grad(
114
114
  return grad_u
115
115
 
116
116
 
117
- class MultiSpectralProjectorFromMasks(torch.autograd.Function):
117
+ class EigvecOuterProduct(torch.autograd.Function):
118
118
  """
119
119
  A (symmetric) -> {P_b}_b, where P_b = U_{S_b} U_{S_b}^T and S_b is specified by a boolean mask.
120
120
 
@@ -212,22 +212,24 @@ class MultiSpectralProjectorFromMasks(torch.autograd.Function):
212
212
  grad_A_used = grad_A_used + (Bmat + Bmat.T)
213
213
 
214
214
  grad_A = 0.5 * (grad_A_used + grad_A_used.T) if symmetrize else grad_A_used
215
+ # print(f"grad_A.shape: {grad_A.shape}, grad_A.norm: {grad_A.norm()}")
215
216
  return grad_A, None, None, None
216
217
 
217
218
 
218
- def spectral_projectors_from_masks(
219
+ def eigvec_outer_product(
219
220
  A: torch.Tensor,
220
- masks: torch.Tensor,
221
+ eigval_masks: torch.Tensor,
221
222
  gap_eps: float = 0.0,
222
223
  symmetrize: bool = True,
223
224
  ):
224
225
  """
225
- Convenience wrapper.
226
+ Computes the outer product of the eigenvectors U U^T, where U is the eigenvector matrix of A.
227
+ gradient of this function is stable, even if the eigenvalues are close to each other.
226
228
 
227
229
  masks: [B,N] bool in DESCENDING eigen-order (0 = largest eigenvalue).
228
230
  returns P: [B,N,N]
229
231
  """
230
- return MultiSpectralProjectorFromMasks.apply(A, masks, gap_eps, symmetrize)
232
+ return EigvecOuterProduct.apply(A, eigval_masks, gap_eps, symmetrize)
231
233
 
232
234
 
233
235
  if __name__ == "__main__":
@@ -241,12 +243,12 @@ if __name__ == "__main__":
241
243
  A1 = torch.randn(N, N)
242
244
  A1 = 0.5 * (A1 + A1.T)
243
245
  A1.requires_grad_(True)
244
- P1 = spectral_projectors_from_masks(A1, masks)
246
+ P1 = eigvec_outer_product(A1, masks)
245
247
 
246
248
  A2 = torch.randn(N, N)
247
249
  A2 = 0.5 * (A2 + A2.T)
248
250
  A2.requires_grad_(True)
249
- P2 = spectral_projectors_from_masks(A2, masks)
251
+ P2 = eigvec_outer_product(A2, masks)
250
252
 
251
253
  loss = torch.norm(P1 - P2, p=2, dim=(0, 1)).sum()
252
254
  loss.backward()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ncut_pytorch
3
- Version: 3.0.0.dev6
3
+ Version: 3.0.0.dev8
4
4
  Summary: Normalized Cut and Spectral Embedding
5
5
  Author-email: Huzheng Yang <huze.yann@gmail.com>
6
6
  License-Expression: MIT
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "ncut_pytorch"
7
- version = "3.0.0dev6"
7
+ version = "3.0.0dev8"
8
8
  authors = [
9
9
  { name = "Huzheng Yang", email = "huze.yann@gmail.com" },
10
10
  ]