ncut-pytorch 2.3.3__tar.gz → 3.0.0.dev0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/PKG-INFO +3 -3
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/README.md +1 -1
- ncut_pytorch-3.0.0.dev0/ncut_pytorch/__init__.py +4 -0
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/color/coloring.py +2 -2
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/color/mspace.py +104 -96
- ncut_pytorch-3.0.0.dev0/ncut_pytorch/color/mspace_nopl.py +550 -0
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/ncut.py +17 -17
- ncut_pytorch-3.0.0.dev0/ncut_pytorch/ncuts/__init__.py +0 -0
- ncut_pytorch-3.0.0.dev0/ncut_pytorch/ncuts/ncut_click.py +106 -0
- ncut_pytorch-3.0.0.dev0/ncut_pytorch/ncuts/ncut_kway.py +163 -0
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/ncuts/ncut_nystrom.py +50 -44
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/dino/dinov3.py +27 -0
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/predictor.py +20 -13
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/vision_predictor.py +4 -3
- ncut_pytorch-3.0.0.dev0/ncut_pytorch/utils/__init__.py +0 -0
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/utils/device.py +2 -0
- ncut_pytorch-3.0.0.dev0/ncut_pytorch/utils/grad.py +154 -0
- ncut_pytorch-3.0.0.dev0/ncut_pytorch/utils/math.py +217 -0
- ncut_pytorch-3.0.0.dev0/ncut_pytorch/utils/sample.py +64 -0
- ncut_pytorch-3.0.0.dev0/ncut_pytorch/utils/sigma.py +65 -0
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/utils/torch_mod.py +2 -0
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch.egg-info/PKG-INFO +3 -3
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch.egg-info/SOURCES.txt +2 -1
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch.egg-info/requires.txt +1 -1
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/pyproject.toml +2 -2
- ncut_pytorch-2.3.3/ncut_pytorch/__init__.py +0 -15
- ncut_pytorch-2.3.3/ncut_pytorch/color/__init__.py +0 -6
- ncut_pytorch-2.3.3/ncut_pytorch/ncuts/__init__.py +0 -3
- ncut_pytorch-2.3.3/ncut_pytorch/ncuts/ncut_click.py +0 -107
- ncut_pytorch-2.3.3/ncut_pytorch/ncuts/ncut_kway.py +0 -110
- ncut_pytorch-2.3.3/ncut_pytorch/utils/gamma.py +0 -60
- ncut_pytorch-2.3.3/ncut_pytorch/utils/grad.py +0 -41
- ncut_pytorch-2.3.3/ncut_pytorch/utils/math.py +0 -339
- ncut_pytorch-2.3.3/ncut_pytorch/utils/sample.py +0 -93
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/LICENSE +0 -0
- {ncut_pytorch-2.3.3/ncut_pytorch/utils → ncut_pytorch-3.0.0.dev0/ncut_pytorch/color}/__init__.py +0 -0
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/__init__.py +0 -0
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/dino/__init__.py +0 -0
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/dino/api.py +0 -0
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/dino/hires_dino.py +0 -0
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/dino/lowres_dino.py +0 -0
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/dino/patch.py +0 -0
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/dino/transform.py +0 -0
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/dino_predictor.py +0 -0
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch/predictor/jafar_predictor.py +0 -0
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch.egg-info/dependency_links.txt +0 -0
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/ncut_pytorch.egg-info/top_level.txt +0 -0
- {ncut_pytorch-2.3.3 → ncut_pytorch-3.0.0.dev0}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ncut_pytorch
|
|
3
|
-
Version: 2.3.3
|
|
3
|
+
Version: 3.0.0.dev0
|
|
4
4
|
Summary: Normalized Cut and Spectral Embedding
|
|
5
5
|
Author-email: Huzheng Yang <huze.yann@gmail.com>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -17,7 +17,7 @@ Requires-Python: >=3
|
|
|
17
17
|
Description-Content-Type: text/markdown
|
|
18
18
|
License-File: LICENSE
|
|
19
19
|
Requires-Dist: numpy<2.0
|
|
20
|
-
Requires-Dist:
|
|
20
|
+
Requires-Dist: torch-quickfps
|
|
21
21
|
Requires-Dist: tqdm
|
|
22
22
|
Requires-Dist: pillow
|
|
23
23
|
Requires-Dist: opencv-python
|
|
@@ -41,7 +41,7 @@ Requires-Dist: ncut_pytorch[cmap,dev,torch]; extra == "all"
|
|
|
41
41
|
Dynamic: license-file
|
|
42
42
|
|
|
43
43
|
|
|
44
|
-
### [🌐Documentation
|
|
44
|
+
### [🌐Documentation](https://ncut-pytorch.readthedocs.io/) | [🤗HuggingFace Demo](https://huggingface.co/spaces/huzey/ncut-pytorch)
|
|
45
45
|
|
|
46
46
|
|
|
47
47
|
## Nyström Normalized Cut
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
|
|
2
|
-
### [🌐Documentation
|
|
2
|
+
### [🌐Documentation](https://ncut-pytorch.readthedocs.io/) | [🤗HuggingFace Demo](https://huggingface.co/spaces/huzey/ncut-pytorch)
|
|
3
3
|
|
|
4
4
|
|
|
5
5
|
## Nyström Normalized Cut
|
|
@@ -7,7 +7,6 @@ from numba.core.types import none
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
import torch
|
|
9
9
|
|
|
10
|
-
from .mspace import mspace_viz_transform
|
|
11
10
|
from ncut_pytorch.ncuts.ncut_nystrom import nystrom_propagate
|
|
12
11
|
from ncut_pytorch.utils.math import quantile_normalize
|
|
13
12
|
from ncut_pytorch.utils.sample import farthest_point_sampling
|
|
@@ -20,7 +19,7 @@ def _identity(X: torch.Tensor) -> torch.Tensor:
|
|
|
20
19
|
def mspace_color(
|
|
21
20
|
X: torch.Tensor,
|
|
22
21
|
q: float = 0.95,
|
|
23
|
-
n_eig: Optional[int] =
|
|
22
|
+
n_eig: Optional[int] = 8,
|
|
24
23
|
n_dim: int = 3,
|
|
25
24
|
training_steps: int = 1000,
|
|
26
25
|
progress_bar: bool = False,
|
|
@@ -31,6 +30,7 @@ def mspace_color(
|
|
|
31
30
|
(torch.Tensor): Embedding in 2D, shape (n_samples, 2)
|
|
32
31
|
(torch.Tensor): RGB color for each data sample, shape (n_samples, 3)
|
|
33
32
|
"""
|
|
33
|
+
from .mspace import mspace_viz_transform
|
|
34
34
|
|
|
35
35
|
low_dim_embedding = mspace_viz_transform(
|
|
36
36
|
X=X,
|
|
@@ -4,8 +4,9 @@ import logging
|
|
|
4
4
|
from collections import defaultdict
|
|
5
5
|
from functools import partial
|
|
6
6
|
import warnings
|
|
7
|
-
|
|
8
7
|
import pytorch_lightning as pl
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
9
10
|
import torch
|
|
10
11
|
import torch.nn as nn
|
|
11
12
|
import torch.nn.functional as F
|
|
@@ -13,23 +14,8 @@ from torch.utils.data import TensorDataset
|
|
|
13
14
|
from tqdm import tqdm
|
|
14
15
|
|
|
15
16
|
from ncut_pytorch.utils.device import auto_device
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
logging.getLogger('lightning').setLevel(0)
|
|
19
|
-
logging.getLogger("pytorch_lightning").setLevel(0)
|
|
20
|
-
class IgnorePLFilter(logging.Filter):
|
|
21
|
-
def filter(self, record):
|
|
22
|
-
keywords = ['available:', 'CUDA', 'LOCAL_RANK:']
|
|
23
|
-
return not any(keyword in record.getMessage() for keyword in keywords)
|
|
24
|
-
logging.getLogger('pytorch_lightning.utilities.rank_zero').addFilter(IgnorePLFilter())
|
|
25
|
-
logging.getLogger('pytorch_lightning.accelerators.cuda').addFilter(IgnorePLFilter())
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
from ncut_pytorch.utils.math import rbf_affinity, cosine_affinity
|
|
29
|
-
from ncut_pytorch.ncuts.ncut_nystrom import _plain_ncut
|
|
30
|
-
from ncut_pytorch.ncuts.ncut_kway import kway_ncut
|
|
31
|
-
from ncut_pytorch.utils.gamma import find_gamma_by_degree_after_fps
|
|
32
|
-
from ncut_pytorch.utils.math import compute_riemann_curvature_loss, compute_boundary_loss, compute_repulsion_loss, compute_axis_align_loss, compute_attraction_loss, find_elbow
|
|
17
|
+
from ncut_pytorch import ncut_fn
|
|
18
|
+
from ncut_pytorch.utils.math import grad_safe_eig_solve, normalize_affinity, rbf_affinity, svd_lowrank, cosine_affinity
|
|
33
19
|
|
|
34
20
|
|
|
35
21
|
def _flag_ncut_loss(eigvec_gt, eigvec_hat, n_eig, weight):
|
|
@@ -65,10 +51,21 @@ def filter_closeby_eigval(eigvec, eigval, threshold=1e-3):
|
|
|
65
51
|
keep_idx = torch.where(eigval_diff > threshold)[0]
|
|
66
52
|
return eigvec[:, keep_idx], eigval[keep_idx]
|
|
67
53
|
|
|
68
|
-
def ncut_wrapper(features, n_eig,
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
54
|
+
def ncut_wrapper(features, n_eig, sigma=None):
|
|
55
|
+
# TODO: gradient is not stable.
|
|
56
|
+
|
|
57
|
+
# features.requires_grad_(True)
|
|
58
|
+
sigma = sigma or features.std(0).sum().item()
|
|
59
|
+
# eigvec, eigval = ncut_fn(features, n_eig, sigma=sigma, track_grad=True)
|
|
60
|
+
W = rbf_affinity(features, sigma=sigma)
|
|
61
|
+
# W = cosine_affinity(features, sigma=1.0)
|
|
62
|
+
A = normalize_affinity(W)
|
|
63
|
+
eigvec, eigval, _ = svd_lowrank(A, n_eig)
|
|
64
|
+
# eigval, eigvec = torch.linalg.eigh(A)
|
|
65
|
+
# eigvec = eigvec.flip(dims=[1])
|
|
66
|
+
# eigval = eigval.flip(dims=[0])
|
|
67
|
+
# eigvec = eigvec[:, :n_eig]
|
|
68
|
+
# eigval = eigval[:n_eig]
|
|
72
69
|
return eigvec, eigval
|
|
73
70
|
|
|
74
71
|
|
|
@@ -332,11 +329,11 @@ class TrainEncoder(pl.LightningModule):
|
|
|
332
329
|
input_feats, output_feats = batch
|
|
333
330
|
|
|
334
331
|
stored_eigvec_gt = {}
|
|
335
|
-
with torch.no_grad():
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
332
|
+
# with torch.no_grad():
|
|
333
|
+
# Compute eigvec_gt only once for each set of iterations
|
|
334
|
+
if self.eigvec_loss != 0:
|
|
335
|
+
eigvec_gt, eigval_gt = ncut_wrapper(input_feats, self.n_eig, sigma=input_feats.std(0).sum().item())
|
|
336
|
+
stored_eigvec_gt = eigvec_gt
|
|
340
337
|
|
|
341
338
|
|
|
342
339
|
# Run the same batch 10 times, updating parameters after each iteration
|
|
@@ -348,7 +345,9 @@ class TrainEncoder(pl.LightningModule):
|
|
|
348
345
|
|
|
349
346
|
total_loss = 0
|
|
350
347
|
if self.eigvec_loss != 0:
|
|
351
|
-
|
|
348
|
+
assert feats_compressed.requires_grad, "feats_compressed must be a tensor with requires_grad=True"
|
|
349
|
+
eigvec_hat, eigval_hat = ncut_wrapper(feats_compressed, self.n_eig, sigma=feats_compressed.std(0).sum().item())
|
|
350
|
+
assert eigvec_hat.requires_grad, "eigvec_hat must be a tensor with requires_grad=True"
|
|
352
351
|
eigvec_loss = flag_space_loss(stored_eigvec_gt, eigvec_hat, n_eig=self.n_eig, weight=None)
|
|
353
352
|
eigvec_loss = eigvec_loss * self.eigvec_loss
|
|
354
353
|
total_loss += eigvec_loss
|
|
@@ -361,37 +360,6 @@ class TrainEncoder(pl.LightningModule):
|
|
|
361
360
|
total_loss += recon_loss
|
|
362
361
|
self._log_loss(recon_loss, "recon", log_grad_norm=log_grad_norm)
|
|
363
362
|
|
|
364
|
-
|
|
365
|
-
if self.riemann_curvature_loss > 0:
|
|
366
|
-
riemann_curvature_loss = compute_riemann_curvature_loss(feats_compressed)
|
|
367
|
-
riemann_curvature_loss = riemann_curvature_loss * self.riemann_curvature_loss
|
|
368
|
-
total_loss += riemann_curvature_loss
|
|
369
|
-
self._log_loss(riemann_curvature_loss, "riemann_curvature", log_grad_norm=log_grad_norm)
|
|
370
|
-
|
|
371
|
-
if self.axis_align_loss > 0:
|
|
372
|
-
axis_align_loss = compute_axis_align_loss(feats_compressed)
|
|
373
|
-
axis_align_loss = axis_align_loss * self.axis_align_loss
|
|
374
|
-
total_loss += axis_align_loss
|
|
375
|
-
self._log_loss(axis_align_loss, "axis_align", log_grad_norm=log_grad_norm)
|
|
376
|
-
|
|
377
|
-
if self.repulsion_loss > 0:
|
|
378
|
-
repulsion_loss = compute_repulsion_loss(feats_compressed)
|
|
379
|
-
repulsion_loss = repulsion_loss * self.repulsion_loss
|
|
380
|
-
total_loss += repulsion_loss
|
|
381
|
-
self._log_loss(repulsion_loss, "repulsion", log_grad_norm=log_grad_norm)
|
|
382
|
-
|
|
383
|
-
if self.attraction_loss > 0:
|
|
384
|
-
attraction_loss = compute_attraction_loss(feats_compressed)
|
|
385
|
-
attraction_loss = attraction_loss * self.attraction_loss
|
|
386
|
-
total_loss += attraction_loss
|
|
387
|
-
self._log_loss(attraction_loss, "attraction", log_grad_norm=log_grad_norm)
|
|
388
|
-
|
|
389
|
-
if self.boundary_loss > 0:
|
|
390
|
-
boundary_loss = compute_boundary_loss(feats_compressed)
|
|
391
|
-
boundary_loss = boundary_loss * self.boundary_loss
|
|
392
|
-
total_loss += boundary_loss
|
|
393
|
-
self._log_loss(boundary_loss, "boundary", log_grad_norm=log_grad_norm)
|
|
394
|
-
|
|
395
363
|
if self.zero_center_loss > 0:
|
|
396
364
|
# zero_center_loss = feats_compressed.abs().mean()
|
|
397
365
|
zero_center_loss = (feats_compressed ** 2).mean()
|
|
@@ -499,48 +467,87 @@ class BestModelsAvgCallback(pl.Callback):
|
|
|
499
467
|
self.best_loss = float('inf')
|
|
500
468
|
|
|
501
469
|
|
|
470
|
+
def compute_repulsion_loss(
|
|
471
|
+
points: torch.Tensor, # [N, D]
|
|
472
|
+
) -> torch.Tensor:
|
|
473
|
+
"""Computes repulsion loss between points to prevent collapse."""
|
|
474
|
+
dist_matrix = torch.cdist(points, points)
|
|
475
|
+
mask = torch.eye(points.shape[0], device=points.device).bool()
|
|
476
|
+
dist_matrix = dist_matrix + mask * 1e10
|
|
477
|
+
nearest_dists, _ = torch.min(dist_matrix, dim=1)
|
|
478
|
+
repulsion = 1.0 / (nearest_dists + 0.01)
|
|
479
|
+
return torch.mean(repulsion)
|
|
480
|
+
|
|
481
|
+
|
|
502
482
|
def train_mspace_model(compress_feats, uncompress_feats, training_steps=500, decoder_training_steps=1000,
|
|
503
483
|
batch_size=1000, return_trainer=False, progress_bar=True,
|
|
504
484
|
logger=False, use_wandb=False, model_avg_window=3, **model_kwargs):
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
c_out = uncompress_feats.shape[1]
|
|
509
|
-
|
|
510
|
-
model = TrainEncoder(c_in, c_out, training_steps=training_steps, progress_bar=progress_bar, **model_kwargs)
|
|
485
|
+
# Disable lightning logs only for this function
|
|
486
|
+
original_lightning_level = logging.getLogger('lightning').level
|
|
487
|
+
original_pytorch_lightning_level = logging.getLogger("pytorch_lightning").level
|
|
511
488
|
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
trainer_args = {
|
|
516
|
-
'accelerator': auto_device(),
|
|
517
|
-
'devices': 1,
|
|
518
|
-
'enable_checkpointing': False,
|
|
519
|
-
'enable_progress_bar': False,
|
|
520
|
-
'enable_model_summary': False,
|
|
521
|
-
'callbacks': [BestModelsAvgCallback(top_k=model_avg_window)],
|
|
522
|
-
}
|
|
523
|
-
|
|
524
|
-
if use_wandb and not logger:
|
|
525
|
-
logger = pl.loggers.WandbLogger(project='mspace', name='mspace')
|
|
526
|
-
|
|
527
|
-
# train the autoencoder jointly
|
|
528
|
-
trainer = pl.Trainer(max_steps=training_steps, logger=logger, **trainer_args)
|
|
529
|
-
trainer.fit(model, dataloader)
|
|
489
|
+
logging.getLogger('lightning').setLevel(0)
|
|
490
|
+
logging.getLogger("pytorch_lightning").setLevel(0)
|
|
530
491
|
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
492
|
+
class IgnorePLFilter(logging.Filter):
|
|
493
|
+
def filter(self, record):
|
|
494
|
+
keywords = ['available:', 'CUDA', 'LOCAL_RANK:']
|
|
495
|
+
return not any(keyword in record.getMessage() for keyword in keywords)
|
|
496
|
+
|
|
497
|
+
pl_filter = IgnorePLFilter()
|
|
498
|
+
logger_rank_zero = logging.getLogger('pytorch_lightning.utilities.rank_zero')
|
|
499
|
+
logger_cuda = logging.getLogger('pytorch_lightning.accelerators.cuda')
|
|
500
|
+
logger_rank_zero.addFilter(pl_filter)
|
|
501
|
+
logger_cuda.addFilter(pl_filter)
|
|
502
|
+
|
|
503
|
+
try:
|
|
504
|
+
compress_feats = compress_feats.float().cpu()
|
|
505
|
+
uncompress_feats = uncompress_feats.float().cpu()
|
|
506
|
+
l, c_in = compress_feats.shape
|
|
507
|
+
c_out = uncompress_feats.shape[1]
|
|
542
508
|
|
|
543
|
-
|
|
509
|
+
model = TrainEncoder(c_in, c_out, training_steps=training_steps, progress_bar=progress_bar, **model_kwargs)
|
|
510
|
+
|
|
511
|
+
dataset = TensorDataset(compress_feats, uncompress_feats)
|
|
512
|
+
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
|
|
513
|
+
|
|
514
|
+
trainer_args = {
|
|
515
|
+
'accelerator': auto_device(),
|
|
516
|
+
'devices': 1,
|
|
517
|
+
'enable_checkpointing': False,
|
|
518
|
+
'enable_progress_bar': False,
|
|
519
|
+
'enable_model_summary': False,
|
|
520
|
+
'callbacks': [BestModelsAvgCallback(top_k=model_avg_window)],
|
|
521
|
+
}
|
|
522
|
+
|
|
523
|
+
if use_wandb and not logger:
|
|
524
|
+
logger = pl.loggers.WandbLogger(project='mspace', name='mspace')
|
|
525
|
+
|
|
526
|
+
# train the autoencoder jointly
|
|
527
|
+
trainer = pl.Trainer(max_steps=training_steps, logger=logger, **trainer_args)
|
|
528
|
+
trainer.fit(model, dataloader)
|
|
529
|
+
|
|
530
|
+
mspace_ae = model.mspace_ae
|
|
531
|
+
|
|
532
|
+
if decoder_training_steps > 0:
|
|
533
|
+
# train the decoder only
|
|
534
|
+
trainer = pl.Trainer(max_steps=decoder_training_steps, logger=logger, **trainer_args)
|
|
535
|
+
model2 = TrainDecoder(mspace_ae, progress_bar=progress_bar, training_steps=decoder_training_steps, **model_kwargs)
|
|
536
|
+
trainer.fit(model2, dataloader)
|
|
537
|
+
model.mspace_ae = model2.mspace_ae
|
|
538
|
+
|
|
539
|
+
if return_trainer:
|
|
540
|
+
result = model, trainer
|
|
541
|
+
else:
|
|
542
|
+
result = model
|
|
543
|
+
finally:
|
|
544
|
+
# Restore original logging configuration
|
|
545
|
+
logging.getLogger('lightning').setLevel(original_lightning_level)
|
|
546
|
+
logging.getLogger("pytorch_lightning").setLevel(original_pytorch_lightning_level)
|
|
547
|
+
logger_rank_zero.removeFilter(pl_filter)
|
|
548
|
+
logger_cuda.removeFilter(pl_filter)
|
|
549
|
+
|
|
550
|
+
return result
|
|
544
551
|
|
|
545
552
|
|
|
546
553
|
def try_train_mspace(*args, **kwargs):
|
|
@@ -576,7 +583,8 @@ def try_train_mspace(*args, **kwargs):
|
|
|
576
583
|
|
|
577
584
|
def mspace_viz_transform(X, return_model=False, **kwargs):
|
|
578
585
|
X = X.float().cpu()
|
|
579
|
-
model, trainer = try_train_mspace(X, X, return_trainer=True, **kwargs)
|
|
586
|
+
# model, trainer = try_train_mspace(X, X, return_trainer=True, **kwargs)
|
|
587
|
+
model, trainer = train_mspace_model(X, X, return_trainer=True, **kwargs)
|
|
580
588
|
|
|
581
589
|
batch_size = kwargs.get('batch_size', 1000)
|
|
582
590
|
test_loader = torch.utils.data.DataLoader(TensorDataset(X), batch_size=batch_size, shuffle=False, num_workers=0)
|