careamics 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

@@ -107,6 +107,15 @@ def get_first_index(bin_count, quantile):
107
107
  return None
108
108
 
109
109
 
110
+ def get_device():
111
+ if torch.cuda.is_available():
112
+ return "cuda"
113
+ elif torch.backends.mps.is_available():
114
+ return "mps"
115
+ else:
116
+ return "cpu"
117
+
118
+
110
119
  def show_for_one(
111
120
  idx,
112
121
  val_dset,
@@ -470,6 +479,7 @@ def get_predictions(
470
479
  dset: Dataset,
471
480
  batch_size: int,
472
481
  tile_size: Optional[tuple[int, int]] = None,
482
+ grid_size: Optional[int] = None,
473
483
  mmse_count: int = 1,
474
484
  num_workers: int = 4,
475
485
  ) -> tuple[dict, dict, dict]:
@@ -510,11 +520,12 @@ def get_predictions(
510
520
  dset=d,
511
521
  batch_size=batch_size,
512
522
  tile_size=tile_size,
523
+ grid_size=grid_size,
513
524
  mmse_count=mmse_count,
514
525
  num_workers=num_workers,
515
526
  )
516
527
  # get filename without extension and path
517
- filename = str(d._fpath).split("/")[-1].split(".")[0]
528
+ filename = d._fpath.name
518
529
  multifile_stitched_predictions[filename] = stitched_predictions
519
530
  multifile_stitched_stds[filename] = stitched_stds
520
531
  return (
@@ -527,11 +538,12 @@ def get_predictions(
527
538
  dset=dset,
528
539
  batch_size=batch_size,
529
540
  tile_size=tile_size,
541
+ grid_size=grid_size,
530
542
  mmse_count=mmse_count,
531
543
  num_workers=num_workers,
532
544
  )
533
545
  # get filename without extension and path
534
- filename = str(dset._fpath).split("/")[-1].split(".")[0]
546
+ filename = dset._fpath.name
535
547
  return (
536
548
  {filename: stitched_predictions},
537
549
  {filename: stitched_stds},
@@ -550,6 +562,8 @@ def get_single_file_predictions(
550
562
  if tile_size and grid_size:
551
563
  dset.set_img_sz(tile_size, grid_size)
552
564
 
565
+ device = get_device()
566
+
553
567
  dloader = DataLoader(
554
568
  dset,
555
569
  pin_memory=False,
@@ -558,14 +572,14 @@ def get_single_file_predictions(
558
572
  batch_size=batch_size,
559
573
  )
560
574
  model.eval()
561
- model.cuda()
575
+ model.to(device)
562
576
  tiles = []
563
577
  logvar_arr = []
564
578
  with torch.no_grad():
565
579
  for batch in tqdm(dloader, desc="Predicting tiles"):
566
580
  inp, tar = batch
567
- inp = inp.cuda()
568
- tar = tar.cuda()
581
+ inp = inp.to(device)
582
+ tar = tar.to(device)
569
583
 
570
584
  # get model output
571
585
  rec, _ = model(inp)
@@ -589,10 +603,13 @@ def get_single_file_mmse(
589
603
  dset: Dataset,
590
604
  batch_size: int,
591
605
  tile_size: Optional[tuple[int, int]] = None,
606
+ grid_size: Optional[int] = None,
592
607
  mmse_count: int = 1,
593
608
  num_workers: int = 4,
594
609
  ) -> tuple[np.ndarray, np.ndarray]:
595
610
  """Get patch-wise predictions from a model for a single file dataset."""
611
+ device = get_device()
612
+
596
613
  dloader = DataLoader(
597
614
  dset,
598
615
  pin_memory=False,
@@ -600,18 +617,19 @@ def get_single_file_mmse(
600
617
  shuffle=False,
601
618
  batch_size=batch_size,
602
619
  )
603
- if tile_size:
604
- dset.set_img_sz(tile_size, tile_size[-1] // 2)
620
+ if tile_size and grid_size:
621
+ dset.set_img_sz(tile_size, grid_size)
622
+
605
623
  model.eval()
606
- model.cuda()
624
+ model.to(device)
607
625
  tile_mmse = []
608
626
  tile_stds = []
609
627
  logvar_arr = []
610
628
  with torch.no_grad():
611
629
  for batch in tqdm(dloader, desc="Predicting tiles"):
612
630
  inp, tar = batch
613
- inp = inp.cuda()
614
- tar = tar.cuda()
631
+ inp = inp.to(device)
632
+ tar = tar.to(device)
615
633
 
616
634
  rec_img_list = []
617
635
  for _ in range(mmse_count):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: careamics
3
- Version: 0.0.7
3
+ Version: 0.0.9
4
4
  Summary: Toolbox for running N2V and friends.
5
5
  Project-URL: homepage, https://careamics.github.io/
6
6
  Project-URL: repository, https://github.com/CAREamics/careamics
@@ -106,7 +106,7 @@ careamics/losses/lvae/loss_utils.py,sha256=QxzA2N1TglR4H0X0uyTWWytDagE1lA9IB_TK1
106
106
  careamics/losses/lvae/losses.py,sha256=wHT1dx04BZ_OI-_S7cFQ5hFmMetm6FSnuZfwZBBtIpY,17977
107
107
  careamics/lvae_training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
108
108
  careamics/lvae_training/calibration.py,sha256=xHbiLcY2csYos3s7rRSqp7P7G-9wzULcSo1JfVzfIjE,7239
109
- careamics/lvae_training/eval_utils.py,sha256=7N1thslU4IU1lM1tGg3-wa8AFf5_R2lOSQ7ZZ91AUII,30030
109
+ careamics/lvae_training/eval_utils.py,sha256=FxZmmT6vMRluLYnnCEQtLcz5Q45OAqmxXbQo6KPbQEk,30372
110
110
  careamics/lvae_training/get_config.py,sha256=dwVfaQS7nzjQss0E1gGLUpQpjPcOWwLgIhbu3Z0I1rg,3068
111
111
  careamics/lvae_training/lightning_module.py,sha256=ryr7iHqCMzCl5esi6_gEcnKFDQkMrw0EXK9Zfgv1Nek,27186
112
112
  careamics/lvae_training/metrics.py,sha256=KTDAKhe3vh-YxzGibjtkIG2nnUyujbnwqX4xGwaRXwE,6718
@@ -171,8 +171,8 @@ careamics/utils/ram.py,sha256=tksyn8dVX_iJXmrDZDGub32hFZWIaNxnMheO5G1p43I,244
171
171
  careamics/utils/receptive_field.py,sha256=Y2h4c8S6glX3qcx5KHDmO17Kkuyey9voxfoXyqcAfiM,3296
172
172
  careamics/utils/serializers.py,sha256=mILUhz75IMpGKnEzcYu9hlOPG8YIiIW09fk6eZM7Y8k,1427
173
173
  careamics/utils/torch_utils.py,sha256=_Cf3HdlIRl5hxfpUg9aofCSlcW7GSsIJxsbSORXko0U,3010
174
- careamics-0.0.7.dist-info/METADATA,sha256=K3w_i8E8INNeZwEQdVtGHDFS_C0031Bfqzn23e8SSO4,3967
175
- careamics-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
176
- careamics-0.0.7.dist-info/entry_points.txt,sha256=2fSNVXJWDJgFLATVj7MkjFNvpl53amG8tUzC3jf7G1s,53
177
- careamics-0.0.7.dist-info/licenses/LICENSE,sha256=6zdNW-k_xHRKYWUf9tDI_ZplUciFHyj0g16DYuZ2udw,1509
178
- careamics-0.0.7.dist-info/RECORD,,
174
+ careamics-0.0.9.dist-info/METADATA,sha256=WX-n7dq4Qr5djLTSxELRii5NKHf_XqNz_vuJcyNEE7Y,3967
175
+ careamics-0.0.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
176
+ careamics-0.0.9.dist-info/entry_points.txt,sha256=2fSNVXJWDJgFLATVj7MkjFNvpl53amG8tUzC3jf7G1s,53
177
+ careamics-0.0.9.dist-info/licenses/LICENSE,sha256=6zdNW-k_xHRKYWUf9tDI_ZplUciFHyj0g16DYuZ2udw,1509
178
+ careamics-0.0.9.dist-info/RECORD,,