careamics-0.0.8-py3-none-any.whl → careamics-0.0.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; consult the registry's advisory page for this release for more details.

@@ -107,6 +107,15 @@ def get_first_index(bin_count, quantile):
107
107
  return None
108
108
 
109
109
 
110
+ def get_device():
111
+ if torch.cuda.is_available():
112
+ return "cuda"
113
+ elif torch.backends.mps.is_available():
114
+ return "mps"
115
+ else:
116
+ return "cpu"
117
+
118
+
110
119
  def show_for_one(
111
120
  idx,
112
121
  val_dset,
@@ -516,7 +525,7 @@ def get_predictions(
516
525
  num_workers=num_workers,
517
526
  )
518
527
  # get filename without extension and path
519
- filename = str(d._fpath).split("/")[-1].split(".")[0]
528
+ filename = d._fpath.name
520
529
  multifile_stitched_predictions[filename] = stitched_predictions
521
530
  multifile_stitched_stds[filename] = stitched_stds
522
531
  return (
@@ -534,7 +543,7 @@ def get_predictions(
534
543
  num_workers=num_workers,
535
544
  )
536
545
  # get filename without extension and path
537
- filename = str(dset._fpath).split("/")[-1].split(".")[0]
546
+ filename = dset._fpath.name
538
547
  return (
539
548
  {filename: stitched_predictions},
540
549
  {filename: stitched_stds},
@@ -553,6 +562,8 @@ def get_single_file_predictions(
553
562
  if tile_size and grid_size:
554
563
  dset.set_img_sz(tile_size, grid_size)
555
564
 
565
+ device = get_device()
566
+
556
567
  dloader = DataLoader(
557
568
  dset,
558
569
  pin_memory=False,
@@ -561,14 +572,14 @@ def get_single_file_predictions(
561
572
  batch_size=batch_size,
562
573
  )
563
574
  model.eval()
564
- model.cuda()
575
+ model.to(device)
565
576
  tiles = []
566
577
  logvar_arr = []
567
578
  with torch.no_grad():
568
579
  for batch in tqdm(dloader, desc="Predicting tiles"):
569
580
  inp, tar = batch
570
- inp = inp.cuda()
571
- tar = tar.cuda()
581
+ inp = inp.to(device)
582
+ tar = tar.to(device)
572
583
 
573
584
  # get model output
574
585
  rec, _ = model(inp)
@@ -597,6 +608,8 @@ def get_single_file_mmse(
597
608
  num_workers: int = 4,
598
609
  ) -> tuple[np.ndarray, np.ndarray]:
599
610
  """Get patch-wise predictions from a model for a single file dataset."""
611
+ device = get_device()
612
+
600
613
  dloader = DataLoader(
601
614
  dset,
602
615
  pin_memory=False,
@@ -608,15 +621,15 @@ def get_single_file_mmse(
608
621
  dset.set_img_sz(tile_size, grid_size)
609
622
 
610
623
  model.eval()
611
- model.cuda()
624
+ model.to(device)
612
625
  tile_mmse = []
613
626
  tile_stds = []
614
627
  logvar_arr = []
615
628
  with torch.no_grad():
616
629
  for batch in tqdm(dloader, desc="Predicting tiles"):
617
630
  inp, tar = batch
618
- inp = inp.cuda()
619
- tar = tar.cuda()
631
+ inp = inp.to(device)
632
+ tar = tar.to(device)
620
633
 
621
634
  rec_img_list = []
622
635
  for _ in range(mmse_count):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: careamics
3
- Version: 0.0.8
3
+ Version: 0.0.9
4
4
  Summary: Toolbox for running N2V and friends.
5
5
  Project-URL: homepage, https://careamics.github.io/
6
6
  Project-URL: repository, https://github.com/CAREamics/careamics
@@ -106,7 +106,7 @@ careamics/losses/lvae/loss_utils.py,sha256=QxzA2N1TglR4H0X0uyTWWytDagE1lA9IB_TK1
106
106
  careamics/losses/lvae/losses.py,sha256=wHT1dx04BZ_OI-_S7cFQ5hFmMetm6FSnuZfwZBBtIpY,17977
107
107
  careamics/lvae_training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
108
108
  careamics/lvae_training/calibration.py,sha256=xHbiLcY2csYos3s7rRSqp7P7G-9wzULcSo1JfVzfIjE,7239
109
- careamics/lvae_training/eval_utils.py,sha256=2Ij1L1enIwjox16cPrhGzkcVQooBzN_mcC9ma94mJrE,30180
109
+ careamics/lvae_training/eval_utils.py,sha256=FxZmmT6vMRluLYnnCEQtLcz5Q45OAqmxXbQo6KPbQEk,30372
110
110
  careamics/lvae_training/get_config.py,sha256=dwVfaQS7nzjQss0E1gGLUpQpjPcOWwLgIhbu3Z0I1rg,3068
111
111
  careamics/lvae_training/lightning_module.py,sha256=ryr7iHqCMzCl5esi6_gEcnKFDQkMrw0EXK9Zfgv1Nek,27186
112
112
  careamics/lvae_training/metrics.py,sha256=KTDAKhe3vh-YxzGibjtkIG2nnUyujbnwqX4xGwaRXwE,6718
@@ -171,8 +171,8 @@ careamics/utils/ram.py,sha256=tksyn8dVX_iJXmrDZDGub32hFZWIaNxnMheO5G1p43I,244
171
171
  careamics/utils/receptive_field.py,sha256=Y2h4c8S6glX3qcx5KHDmO17Kkuyey9voxfoXyqcAfiM,3296
172
172
  careamics/utils/serializers.py,sha256=mILUhz75IMpGKnEzcYu9hlOPG8YIiIW09fk6eZM7Y8k,1427
173
173
  careamics/utils/torch_utils.py,sha256=_Cf3HdlIRl5hxfpUg9aofCSlcW7GSsIJxsbSORXko0U,3010
174
- careamics-0.0.8.dist-info/METADATA,sha256=jBK6zEhnACTuo9igH-PT-5sphN-tj2-b20j1E2zE7_w,3967
175
- careamics-0.0.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
176
- careamics-0.0.8.dist-info/entry_points.txt,sha256=2fSNVXJWDJgFLATVj7MkjFNvpl53amG8tUzC3jf7G1s,53
177
- careamics-0.0.8.dist-info/licenses/LICENSE,sha256=6zdNW-k_xHRKYWUf9tDI_ZplUciFHyj0g16DYuZ2udw,1509
178
- careamics-0.0.8.dist-info/RECORD,,
174
+ careamics-0.0.9.dist-info/METADATA,sha256=WX-n7dq4Qr5djLTSxELRii5NKHf_XqNz_vuJcyNEE7Y,3967
175
+ careamics-0.0.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
176
+ careamics-0.0.9.dist-info/entry_points.txt,sha256=2fSNVXJWDJgFLATVj7MkjFNvpl53amG8tUzC3jf7G1s,53
177
+ careamics-0.0.9.dist-info/licenses/LICENSE,sha256=6zdNW-k_xHRKYWUf9tDI_ZplUciFHyj0g16DYuZ2udw,1509
178
+ careamics-0.0.9.dist-info/RECORD,,