careamics 0.0.8__py3-none-any.whl → 0.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics has been flagged as potentially problematic; see the package's registry page for details.
- careamics/lvae_training/eval_utils.py +21 -8
- {careamics-0.0.8.dist-info → careamics-0.0.9.dist-info}/METADATA +1 -1
- {careamics-0.0.8.dist-info → careamics-0.0.9.dist-info}/RECORD +6 -6
- {careamics-0.0.8.dist-info → careamics-0.0.9.dist-info}/WHEEL +0 -0
- {careamics-0.0.8.dist-info → careamics-0.0.9.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.8.dist-info → careamics-0.0.9.dist-info}/licenses/LICENSE +0 -0
|
@@ -107,6 +107,15 @@ def get_first_index(bin_count, quantile):
|
|
|
107
107
|
return None
|
|
108
108
|
|
|
109
109
|
|
|
110
|
+
def get_device() -> str:
    """Return the name of the best available torch device.

    Backends are checked in priority order: CUDA first, then Apple's Metal
    Performance Shaders (MPS), and finally the CPU as a fallback.

    Returns
    -------
    str
        One of ``"cuda"``, ``"mps"`` or ``"cpu"``; the value can be passed
        directly to ``torch.device`` or ``Module.to``.
    """
    if torch.cuda.is_available():
        return "cuda"
    # ``torch.backends.mps`` is absent from older torch releases; look it up
    # defensively so those builds fall back to the CPU instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
|
|
117
|
+
|
|
118
|
+
|
|
110
119
|
def show_for_one(
|
|
111
120
|
idx,
|
|
112
121
|
val_dset,
|
|
@@ -516,7 +525,7 @@ def get_predictions(
|
|
|
516
525
|
num_workers=num_workers,
|
|
517
526
|
)
|
|
518
527
|
# get filename without extension and path
|
|
519
|
-
filename =
|
|
528
|
+
filename = d._fpath.name
|
|
520
529
|
multifile_stitched_predictions[filename] = stitched_predictions
|
|
521
530
|
multifile_stitched_stds[filename] = stitched_stds
|
|
522
531
|
return (
|
|
@@ -534,7 +543,7 @@ def get_predictions(
|
|
|
534
543
|
num_workers=num_workers,
|
|
535
544
|
)
|
|
536
545
|
# get filename without extension and path
|
|
537
|
-
filename =
|
|
546
|
+
filename = dset._fpath.name
|
|
538
547
|
return (
|
|
539
548
|
{filename: stitched_predictions},
|
|
540
549
|
{filename: stitched_stds},
|
|
@@ -553,6 +562,8 @@ def get_single_file_predictions(
|
|
|
553
562
|
if tile_size and grid_size:
|
|
554
563
|
dset.set_img_sz(tile_size, grid_size)
|
|
555
564
|
|
|
565
|
+
device = get_device()
|
|
566
|
+
|
|
556
567
|
dloader = DataLoader(
|
|
557
568
|
dset,
|
|
558
569
|
pin_memory=False,
|
|
@@ -561,14 +572,14 @@ def get_single_file_predictions(
|
|
|
561
572
|
batch_size=batch_size,
|
|
562
573
|
)
|
|
563
574
|
model.eval()
|
|
564
|
-
model.
|
|
575
|
+
model.to(device)
|
|
565
576
|
tiles = []
|
|
566
577
|
logvar_arr = []
|
|
567
578
|
with torch.no_grad():
|
|
568
579
|
for batch in tqdm(dloader, desc="Predicting tiles"):
|
|
569
580
|
inp, tar = batch
|
|
570
|
-
inp = inp.
|
|
571
|
-
tar = tar.
|
|
581
|
+
inp = inp.to(device)
|
|
582
|
+
tar = tar.to(device)
|
|
572
583
|
|
|
573
584
|
# get model output
|
|
574
585
|
rec, _ = model(inp)
|
|
@@ -597,6 +608,8 @@ def get_single_file_mmse(
|
|
|
597
608
|
num_workers: int = 4,
|
|
598
609
|
) -> tuple[np.ndarray, np.ndarray]:
|
|
599
610
|
"""Get patch-wise predictions from a model for a single file dataset."""
|
|
611
|
+
device = get_device()
|
|
612
|
+
|
|
600
613
|
dloader = DataLoader(
|
|
601
614
|
dset,
|
|
602
615
|
pin_memory=False,
|
|
@@ -608,15 +621,15 @@ def get_single_file_mmse(
|
|
|
608
621
|
dset.set_img_sz(tile_size, grid_size)
|
|
609
622
|
|
|
610
623
|
model.eval()
|
|
611
|
-
model.
|
|
624
|
+
model.to(device)
|
|
612
625
|
tile_mmse = []
|
|
613
626
|
tile_stds = []
|
|
614
627
|
logvar_arr = []
|
|
615
628
|
with torch.no_grad():
|
|
616
629
|
for batch in tqdm(dloader, desc="Predicting tiles"):
|
|
617
630
|
inp, tar = batch
|
|
618
|
-
inp = inp.
|
|
619
|
-
tar = tar.
|
|
631
|
+
inp = inp.to(device)
|
|
632
|
+
tar = tar.to(device)
|
|
620
633
|
|
|
621
634
|
rec_img_list = []
|
|
622
635
|
for _ in range(mmse_count):
|
|
@@ -106,7 +106,7 @@ careamics/losses/lvae/loss_utils.py,sha256=QxzA2N1TglR4H0X0uyTWWytDagE1lA9IB_TK1
|
|
|
106
106
|
careamics/losses/lvae/losses.py,sha256=wHT1dx04BZ_OI-_S7cFQ5hFmMetm6FSnuZfwZBBtIpY,17977
|
|
107
107
|
careamics/lvae_training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
108
108
|
careamics/lvae_training/calibration.py,sha256=xHbiLcY2csYos3s7rRSqp7P7G-9wzULcSo1JfVzfIjE,7239
|
|
109
|
-
careamics/lvae_training/eval_utils.py,sha256=
|
|
109
|
+
careamics/lvae_training/eval_utils.py,sha256=FxZmmT6vMRluLYnnCEQtLcz5Q45OAqmxXbQo6KPbQEk,30372
|
|
110
110
|
careamics/lvae_training/get_config.py,sha256=dwVfaQS7nzjQss0E1gGLUpQpjPcOWwLgIhbu3Z0I1rg,3068
|
|
111
111
|
careamics/lvae_training/lightning_module.py,sha256=ryr7iHqCMzCl5esi6_gEcnKFDQkMrw0EXK9Zfgv1Nek,27186
|
|
112
112
|
careamics/lvae_training/metrics.py,sha256=KTDAKhe3vh-YxzGibjtkIG2nnUyujbnwqX4xGwaRXwE,6718
|
|
@@ -171,8 +171,8 @@ careamics/utils/ram.py,sha256=tksyn8dVX_iJXmrDZDGub32hFZWIaNxnMheO5G1p43I,244
|
|
|
171
171
|
careamics/utils/receptive_field.py,sha256=Y2h4c8S6glX3qcx5KHDmO17Kkuyey9voxfoXyqcAfiM,3296
|
|
172
172
|
careamics/utils/serializers.py,sha256=mILUhz75IMpGKnEzcYu9hlOPG8YIiIW09fk6eZM7Y8k,1427
|
|
173
173
|
careamics/utils/torch_utils.py,sha256=_Cf3HdlIRl5hxfpUg9aofCSlcW7GSsIJxsbSORXko0U,3010
|
|
174
|
-
careamics-0.0.
|
|
175
|
-
careamics-0.0.
|
|
176
|
-
careamics-0.0.
|
|
177
|
-
careamics-0.0.
|
|
178
|
-
careamics-0.0.
|
|
174
|
+
careamics-0.0.9.dist-info/METADATA,sha256=WX-n7dq4Qr5djLTSxELRii5NKHf_XqNz_vuJcyNEE7Y,3967
|
|
175
|
+
careamics-0.0.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
176
|
+
careamics-0.0.9.dist-info/entry_points.txt,sha256=2fSNVXJWDJgFLATVj7MkjFNvpl53amG8tUzC3jf7G1s,53
|
|
177
|
+
careamics-0.0.9.dist-info/licenses/LICENSE,sha256=6zdNW-k_xHRKYWUf9tDI_ZplUciFHyj0g16DYuZ2udw,1509
|
|
178
|
+
careamics-0.0.9.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|