careamics 0.1.0rc1__py3-none-any.whl → 0.1.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

careamics/engine.py CHANGED
@@ -10,12 +10,11 @@ from typing import Any, Dict, List, Optional, Tuple, Union
10
10
 
11
11
  import numpy as np
12
12
  import torch
13
- from bioimageio.spec.model.raw_nodes import Model as BioimageModel
14
13
  from torch.utils.data import DataLoader, TensorDataset
15
14
 
16
15
  from .bioimage import (
17
- build_zip_model,
18
16
  get_default_model_specs,
17
+ save_bioimage_model,
19
18
  )
20
19
  from .config import Configuration, load_configuration
21
20
  from .dataset.prepare_dataset import (
@@ -32,7 +31,7 @@ from .prediction import (
32
31
  )
33
32
  from .utils import (
34
33
  MetricTracker,
35
- check_array_validity,
34
+ add_axes,
36
35
  denormalize,
37
36
  get_device,
38
37
  normalize,
@@ -40,7 +39,6 @@ from .utils import (
40
39
  from .utils.logging import ProgressBar, get_logger
41
40
 
42
41
 
43
- # TODO: refactor private methods and bioimage.io to other modules
44
42
  class Engine:
45
43
  """
46
44
  Class allowing training of a model and subsequent prediction.
@@ -165,47 +163,52 @@ class Engine:
165
163
  self.scaler,
166
164
  self.cfg,
167
165
  ) = create_model(config=self.cfg, model_path=model_path, device=self.device)
166
+ assert self.cfg is not None
168
167
 
169
168
  # create loss function
170
- if self.cfg is not None:
171
- self.loss_func = create_loss_function(self.cfg)
169
+ self.loss_func = create_loss_function(self.cfg)
170
+
171
+ # Set logging
172
+ log_path = self.cfg.working_directory / "log.txt"
173
+ self.logger = get_logger(__name__, log_path=log_path)
172
174
 
173
- # Set logging
174
- log_path = self.cfg.working_directory / "log.txt"
175
- self.logger = get_logger(__name__, log_path=log_path)
175
+ # wandb
176
+ self.use_wandb = self.cfg.training.use_wandb
176
177
 
177
- # wandb
178
- self.use_wandb = self.cfg.training.use_wandb
178
+ if self.use_wandb:
179
+ try:
180
+ from wandb.errors import UsageError
181
+
182
+ from careamics.utils.wandb import WandBLogging
179
183
 
180
- if self.use_wandb:
181
184
  try:
182
- from wandb.errors import UsageError
183
-
184
- from careamics.utils.wandb import WandBLogging
185
-
186
- try:
187
- self.wandb = WandBLogging(
188
- experiment_name=self.cfg.experiment_name,
189
- log_path=self.cfg.working_directory,
190
- config=self.cfg,
191
- model_to_watch=self.model,
192
- )
193
- except UsageError as e:
194
- self.logger.warning(
195
- f"Wandb usage error, using default logger. Check whether "
196
- f"wandb correctly configured:\n"
197
- f"{e}"
198
- )
199
- self.use_wandb = False
200
-
201
- except ModuleNotFoundError:
185
+ self.wandb = WandBLogging(
186
+ experiment_name=self.cfg.experiment_name,
187
+ log_path=self.cfg.working_directory,
188
+ config=self.cfg,
189
+ model_to_watch=self.model,
190
+ )
191
+ except UsageError as e:
202
192
  self.logger.warning(
203
- "Wandb not installed, using default logger. Try pip install "
204
- "wandb"
193
+ f"Wandb usage error, using default logger. Check whether "
194
+ f"wandb correctly configured:\n"
195
+ f"{e}"
205
196
  )
206
197
  self.use_wandb = False
207
- else:
208
- raise ValueError("Configuration is not defined.")
198
+
199
+ except ModuleNotFoundError:
200
+ self.logger.warning(
201
+ "Wandb not installed, using default logger. Try pip install "
202
+ "wandb"
203
+ )
204
+ self.use_wandb = False
205
+
206
+ # BMZ inputs/outputs placeholders, filled during validation
207
+ self._input = None
208
+ self._outputs = None
209
+
210
+ # torch version
211
+ self.torch_version = torch.__version__
209
212
 
210
213
  def train(
211
214
  self,
@@ -387,6 +390,12 @@ class Engine:
387
390
 
388
391
  with torch.no_grad():
389
392
  for patch, *auxillary in val_loader:
393
+ # if inputs is None, record a single patch
394
+ if self._input is None:
395
+ # patch has dimension SC(Z)YX
396
+ self._input = patch.clone().detach().cpu().numpy()
397
+
398
+ # evaluate
390
399
  outputs = self.model(patch.to(self.device))
391
400
  loss = self.loss_func(
392
401
  outputs, *[a.to(self.device) for a in auxillary], self.device
@@ -409,10 +418,18 @@ class Engine:
409
418
  The Engine must have previously been trained and mean/std be specified in
410
419
  its configuration.
411
420
 
421
+ Data should be compatible with the axes, either from the configuration or
422
+ as passed using the `axes` parameter. If the batch and channel dimensions are
423
+ missing, then singleton dimensions are added.
424
+
412
425
  To use tiling, both `tile_shape` and `overlaps` must be specified, have same
413
426
  length, be divisible by 2 and greater than 0. Finally, the overlaps must be
414
427
  smaller than the tiles.
415
428
 
429
+ By setting `tta` to `True`, the prediction is performed using test time
430
+ augmentation, meaning that the input is augmented and the prediction is averaged
431
+ over the augmentations.
432
+
416
433
  Parameters
417
434
  ----------
418
435
  input : Union[np.ndarray, str, Path]
@@ -497,6 +514,18 @@ class Engine:
497
514
  UserWarning
498
515
  If the samples have different shapes, the prediction then returns a list.
499
516
  """
517
+ # checks are done here to satisfy mypy
518
+ # check that configuration exists
519
+ if self.cfg is None:
520
+ raise ValueError("Configuration is not defined, cannot predict.")
521
+
522
+ # Check that the mean and std are there (= has been trained)
523
+ if not self.cfg.data.mean or not self.cfg.data.std:
524
+ raise ValueError(
525
+ "Mean or std are not specified in the configuration, prediction cannot "
526
+ "be performed."
527
+ )
528
+
500
529
  prediction = []
501
530
  tiles = []
502
531
  stitching_data = []
@@ -513,9 +542,7 @@ class Engine:
513
542
  predicted_augments = []
514
543
  for augmented_tile in augmented_tiles:
515
544
  augmented_pred = self.model(augmented_tile.to(self.device))
516
- predicted_augments.append(
517
- augmented_pred.squeeze().cpu().numpy()
518
- )
545
+ predicted_augments.append(augmented_pred.cpu())
519
546
  tiles.append(tta_backward(predicted_augments).squeeze())
520
547
  else:
521
548
  tiles.append(
@@ -529,8 +556,8 @@ class Engine:
529
556
  predicted_sample = stitch_prediction(tiles, stitching_data)
530
557
  predicted_sample = denormalize(
531
558
  predicted_sample,
532
- float(self.cfg.data.mean), # type: ignore
533
- float(self.cfg.data.std), # type: ignore
559
+ float(self.cfg.data.mean),
560
+ float(self.cfg.data.std),
534
561
  )
535
562
  prediction.append(predicted_sample)
536
563
  tiles.clear()
@@ -566,6 +593,18 @@ class Engine:
566
593
  np.ndarray
567
594
  Predicted image.
568
595
  """
596
+ # checks are done here to satisfy mypy
597
+ # check that configuration exists
598
+ if self.cfg is None:
599
+ raise ValueError("Configuration is not defined, cannot predict.")
600
+
601
+ # Check that the mean and std are there (= has been trained)
602
+ if not self.cfg.data.mean or not self.cfg.data.std:
603
+ raise ValueError(
604
+ "Mean or std are not specified in the configuration, prediction cannot "
605
+ "be performed."
606
+ )
607
+
569
608
  prediction = []
570
609
  with torch.no_grad():
571
610
  for i, sample in enumerate(pred_loader):
@@ -574,9 +613,7 @@ class Engine:
574
613
  predicted_augments = []
575
614
  for augmented_pred in augmented_preds:
576
615
  augmented_pred = self.model(augmented_pred.to(self.device))
577
- predicted_augments.append(
578
- augmented_pred.squeeze().cpu().numpy()
579
- )
616
+ predicted_augments.append(augmented_pred.cpu())
580
617
  prediction.append(tta_backward(predicted_augments).squeeze())
581
618
  else:
582
619
  prediction.append(
@@ -584,7 +621,9 @@ class Engine:
584
621
  )
585
622
  progress_bar.update(i, 1)
586
623
  output = denormalize(
587
- np.stack(prediction), float(self.cfg.data.mean), float(self.cfg.data.std) # type: ignore
624
+ np.stack(prediction).squeeze(),
625
+ float(self.cfg.data.mean),
626
+ float(self.cfg.data.std),
588
627
  )
589
628
  return output
590
629
 
@@ -696,14 +735,13 @@ class Engine:
696
735
  )
697
736
 
698
737
  if input is None:
699
- raise ValueError("Ccannot predict on None input.")
738
+ raise ValueError("Input cannot be None.")
700
739
 
701
740
  # Create dataset
702
741
  if isinstance(input, np.ndarray): # np.ndarray
703
- # Check that the axes fit the input
742
+ # Validate axes and add missing dimensions (S)C if necessary
704
743
  img_axes = self.cfg.data.axes if axes is None else axes
705
- # TODO are self.cfg.data.axes and axes compatible (same spatial dim)?
706
- check_array_validity(input, img_axes)
744
+ input_expanded = add_axes(input, img_axes)
707
745
 
708
746
  # Check if tiling requested
709
747
  tiled = tile_shape is not None and overlaps is not None
@@ -714,11 +752,9 @@ class Engine:
714
752
  "Tiling with in memory array is currently not implemented."
715
753
  )
716
754
 
717
- # check_tiling_validity(tile_shape, overlaps)
718
-
719
755
  # Normalize input and cast to float32
720
756
  normalized_input = normalize(
721
- img=input, mean=self.cfg.data.mean, std=self.cfg.data.std
757
+ img=input_expanded, mean=self.cfg.data.mean, std=self.cfg.data.std
722
758
  )
723
759
  normalized_input = normalized_input.astype(np.float32)
724
760
 
@@ -812,7 +848,66 @@ class Engine:
812
848
  self.logger.removeHandler(handler)
813
849
  handler.close()
814
850
 
815
- def _generate_rdf(self, model_specs: Optional[dict] = None) -> dict:
851
+ def _get_sample_io_files(
852
+ self,
853
+ input_array: Optional[np.ndarray] = None,
854
+ axes: Optional[str] = None,
855
+ ) -> Tuple[List[str], List[str]]:
856
+ """
857
+ Create numpy format for use as inputs and outputs in the bioimage.io archive.
858
+
859
+ Parameters
860
+ ----------
861
+ input_array : Optional[np.ndarray], optional
862
+ Input array to use for the bioimage.io model zoo, by default None.
863
+ axes : Optional[str], optional
864
+ Axes from the configuration.
865
+
866
+ Returns
867
+ -------
868
+ Tuple[List[str], List[str]]
869
+ Tuple of input and output file paths.
870
+
871
+ Raises
872
+ ------
873
+ ValueError
874
+ If the configuration is not defined.
875
+ """
876
+ if self.cfg is not None and self._input is not None:
877
+ # use the input array if provided, otherwise use the first validation sample
878
+ if input_array is not None:
879
+ array_in = input_array
880
+
881
+ # add axes to be compatible with the axes declared in the RDF specs
882
+ add_axes(array_in, axes)
883
+ else:
884
+ array_in = self._input
885
+
886
+ # predict (no tta since BMZ does not apply it)
887
+ array_out = self.predict(array_in, tta=False)
888
+
889
+ # add singleton dimensions (for compatibility with model axes)
890
+ # indeed, BMZ applies the model but CAREamics functions are meant
891
+ # to work on user data (potentially with no S or C axis)
892
+ array_out = array_out[np.newaxis, np.newaxis, ...]
893
+
894
+ # save numpy files
895
+ workdir = self.cfg.working_directory
896
+ in_file = workdir.joinpath("test_inputs.npy")
897
+ np.save(in_file, array_in)
898
+ out_file = workdir.joinpath("test_outputs.npy")
899
+ np.save(out_file, array_out)
900
+
901
+ return [str(in_file.absolute())], [str(out_file.absolute())]
902
+ else:
903
+ raise ValueError("Configuration is not defined or model was not trained.")
904
+
905
+ def _generate_rdf(
906
+ self,
907
+ *,
908
+ model_specs: Optional[dict] = None,
909
+ input_array: Optional[np.ndarray] = None,
910
+ ) -> dict:
816
911
  """
817
912
  Generate rdf data for bioimage.io format export.
818
913
 
@@ -820,6 +915,8 @@ class Engine:
820
915
  ----------
821
916
  model_specs : Optional[dict], optional
822
917
  Custom specs if different than the default ones, by default None.
918
+ input_array : Optional[np.ndarray], optional
919
+ Input array to use for the bioimage.io model zoo, by default None.
823
920
 
824
921
  Returns
825
922
  -------
@@ -848,7 +945,9 @@ class Engine:
848
945
  axes = "b" + axes
849
946
 
850
947
  # get in/out samples' files
851
- test_inputs, test_outputs = self._get_sample_io_files(axes)
948
+ test_inputs, test_outputs = self._get_sample_io_files(
949
+ input_array, self.cfg.data.axes
950
+ )
852
951
 
853
952
  specs = get_default_model_specs(
854
953
  "Noise2Void",
@@ -861,7 +960,6 @@ class Engine:
861
960
 
862
961
  specs.update(
863
962
  {
864
- "architecture": "careamics.models.unet",
865
963
  "test_inputs": test_inputs,
866
964
  "test_outputs": test_outputs,
867
965
  "input_axes": [axes],
@@ -870,14 +968,21 @@ class Engine:
870
968
  )
871
969
  return specs
872
970
  else:
873
- raise ValueError("Configuration is not defined.")
971
+ raise ValueError("Configuration is not defined or model was not trained.")
874
972
 
875
973
  def save_as_bioimage(
876
- self, output_zip: Union[Path, str], model_specs: Optional[dict] = None
877
- ) -> BioimageModel:
974
+ self,
975
+ output_zip: Union[Path, str],
976
+ model_specs: Optional[dict] = None,
977
+ input_array: Optional[np.ndarray] = None,
978
+ ) -> None:
878
979
  """
879
980
  Export the current model to BioImage.io model zoo format.
880
981
 
982
+ Custom specs can be passed in `model_specs` (e.g. maintainers). For a description
983
+ of the model RDF, refer to
984
+ github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/model_spec_latest.md.
985
+
881
986
  Parameters
882
987
  ----------
883
988
  output_zip : Union[Path, str]
@@ -885,11 +990,10 @@ class Engine:
885
990
  model_specs : Optional[dict]
886
991
  A dictionary with keys being the bioimage-core build_model parameters. If
887
992
  None then it will be populated by the model default specs.
888
-
889
- Returns
890
- -------
891
- BioimageModel
892
- Bioimage.io model object.
993
+ input_array : Optional[np.ndarray]
994
+ An array to use as input for the bioimage.io model zoo. If None then the
995
+ first validation sample will be used. Note that the array must have S and
996
+ C dimensions (e.g. SCYX), even if only singleton dimensions.
893
997
 
894
998
  Raises
895
999
  ------
@@ -898,57 +1002,13 @@ class Engine:
898
1002
  """
899
1003
  if self.cfg is not None:
900
1004
  # Generate specs
901
- specs = self._generate_rdf(model_specs)
1005
+ specs = self._generate_rdf(model_specs=model_specs, input_array=input_array)
902
1006
 
903
1007
  # Build model
904
- raw_model = build_zip_model(
1008
+ save_bioimage_model(
905
1009
  path=output_zip,
906
1010
  config=self.cfg,
907
- model_specs=specs,
1011
+ specs=specs,
908
1012
  )
909
-
910
- return raw_model
911
- else:
912
- raise ValueError("Configuration is not defined.")
913
-
914
- def _get_sample_io_files(self, axes: str) -> Tuple[List[str], List[str]]:
915
- """
916
- Create numpy format for use as inputs and outputs in the bioimage.io archive.
917
-
918
- Parameters
919
- ----------
920
- axes : str
921
- Input and output axes.
922
-
923
- Returns
924
- -------
925
- Tuple[List[str], List[str]]
926
- Tuple of input and output file paths.
927
-
928
- Raises
929
- ------
930
- ValueError
931
- If the configuration is not defined.
932
- """
933
- # input:
934
- if self.cfg is not None:
935
- sample_input = np.random.randn(*self.cfg.training.patch_size)
936
- # if there are more input axes (like channel, ...),
937
- # then expand the sample dimensions.
938
- len_diff = len(axes) - len(self.cfg.training.patch_size)
939
- if len_diff > 0:
940
- sample_input = np.expand_dims(
941
- sample_input, axis=tuple(i for i in range(len_diff))
942
- )
943
- sample_output = np.random.randn(*sample_input.shape)
944
-
945
- # save numpy files
946
- workdir = self.cfg.working_directory
947
- in_file = workdir.joinpath("test_inputs.npy")
948
- np.save(in_file, sample_input)
949
- out_file = workdir.joinpath("test_outputs.npy")
950
- np.save(out_file, sample_output)
951
-
952
- return [str(in_file.absolute())], [str(out_file.absolute())]
953
1013
  else:
954
1014
  raise ValueError("Configuration is not defined.")
@@ -5,8 +5,9 @@ This module contains a factory function for creating loss functions.
5
5
  """
6
6
  from typing import Callable
7
7
 
8
- from ..config import Configuration
9
- from ..config.algorithm import Loss
8
+ from careamics.config import Configuration
9
+ from careamics.config.algorithm import Loss
10
+
10
11
  from .losses import n2v_loss
11
12
 
12
13
 
@@ -8,10 +8,11 @@ from typing import Dict, Optional, Tuple, Union
8
8
 
9
9
  import torch
10
10
 
11
- from ..bioimage import import_bioimage_model
12
- from ..config import Configuration
13
- from ..config.algorithm import Models
14
- from ..utils.logging import get_logger
11
+ from careamics.bioimage import import_bioimage_model
12
+ from careamics.config import Configuration
13
+ from careamics.config.algorithm import Models
14
+ from careamics.utils.logging import get_logger
15
+
15
16
  from .unet import UNet
16
17
 
17
18
  logger = get_logger(__name__)
@@ -49,16 +50,27 @@ def create_model(
49
50
  model_path: Optional[Union[str, Path]] = None,
50
51
  config: Optional[Configuration] = None,
51
52
  device: Optional[torch.device] = None,
52
- ) -> torch.nn.Module:
53
+ ) -> Tuple[
54
+ torch.nn.Module,
55
+ torch.optim.Optimizer,
56
+ Union[
57
+ torch.optim.lr_scheduler.LRScheduler,
58
+ torch.optim.lr_scheduler.ReduceLROnPlateau, # not a subclass of LRScheduler
59
+ ],
60
+ torch.cuda.amp.GradScaler,
61
+ Configuration,
62
+ ]:
53
63
  """
54
- Instantiate a model from a checkpoint or configuration.
64
+ Instantiate a model from a model path or configuration.
55
65
 
56
- If both checkpoint and configuration are provided, the checkpoint is used.
66
+ If both path and configuration are provided, the model path is used. The model
67
+ path should point to either a checkpoint (created during training) or a model
68
+ exported to the bioimage.io format.
57
69
 
58
70
  Parameters
59
71
  ----------
60
72
  model_path : Optional[Union[str, Path]], optional
61
- Path to a checkpoint, by default None.
73
+ Path to a checkpoint or bioimage.io archive, by default None.
62
74
  config : Optional[Configuration], optional
63
75
  Configuration, by default None.
64
76
  device : Optional[torch.device], optional
@@ -103,12 +115,13 @@ def create_model(
103
115
  raise ValueError("Invalid checkpoint format, no configuration found.")
104
116
 
105
117
  # Create model
106
- model = model_registry(model_name)(
118
+ model: torch.nn.Module = model_registry(model_name)(
107
119
  depth=model_config.depth,
108
120
  conv_dim=algo_config.get_conv_dim(),
109
121
  num_channels_init=model_config.num_channels_init,
110
122
  )
111
123
  model.to(device)
124
+
112
125
  # Load the model state dict
113
126
  if "model_state_dict" in checkpoint:
114
127
  model.load_state_dict(checkpoint["model_state_dict"])
@@ -141,13 +154,19 @@ def create_model(
141
154
 
142
155
  else:
143
156
  raise ValueError("Either config or model_path must be provided")
144
- # model = compile_model(model)
157
+
145
158
  return model, optimizer, scheduler, scaler, config
146
159
 
147
160
 
148
161
  def get_optimizer_and_scheduler(
149
162
  config: Configuration, model: torch.nn.Module, state_dict: Optional[Dict] = None
150
- ) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
163
+ ) -> Tuple[
164
+ torch.optim.Optimizer,
165
+ Union[
166
+ torch.optim.lr_scheduler.LRScheduler,
167
+ torch.optim.lr_scheduler.ReduceLROnPlateau, # not a subclass of LRScheduler
168
+ ],
169
+ ]:
151
170
  """
152
171
  Create optimizer and learning rate schedulers.
153
172
 
@@ -46,16 +46,18 @@ def stitch_prediction(
46
46
  return predicted_image
47
47
 
48
48
 
49
- def tta_forward(x: np.ndarray) -> List:
49
+ def tta_forward(x: torch.Tensor) -> List[torch.Tensor]:
50
50
  """
51
51
  Augment 8-fold an array.
52
52
 
53
53
  The augmentation is performed using all 90 deg rotations and their flipped version,
54
54
  as well as the original image flipped.
55
55
 
56
+ Tensors should be of shape SC(Z)YX, with S and C potentially singleton dimensions.
57
+
56
58
  Parameters
57
59
  ----------
58
- x : torch.tensor
60
+ x : torch.Tensor
59
61
  Data to augment.
60
62
 
61
63
  Returns
@@ -75,13 +77,15 @@ def tta_forward(x: np.ndarray) -> List:
75
77
  return x_aug_flip
76
78
 
77
79
 
78
- def tta_backward(x_aug: List) -> np.ndarray:
80
+ def tta_backward(x_aug: List[torch.Tensor]) -> np.ndarray:
79
81
  """
80
82
  Invert `tta_forward` and average the 8 images.
81
83
 
84
+ The function takes a list of torch tensors and returns a numpy array.
85
+
82
86
  Parameters
83
87
  ----------
84
- x_aug : List
88
+ x_aug : List[torch.Tensor]
85
89
  Stack of 8-fold augmented images.
86
90
 
87
91
  Returns
@@ -90,13 +94,13 @@ def tta_backward(x_aug: List) -> np.ndarray:
90
94
  Average of de-augmented x_aug.
91
95
  """
92
96
  x_deaug = [
93
- x_aug[0],
94
- np.rot90(x_aug[1], -1),
95
- np.rot90(x_aug[2], -2),
96
- np.rot90(x_aug[3], -3),
97
- np.fliplr(x_aug[4]),
98
- np.rot90(np.fliplr(x_aug[5]), -1),
99
- np.rot90(np.fliplr(x_aug[6]), -2),
100
- np.rot90(np.fliplr(x_aug[7]), -3),
97
+ x_aug[0].numpy(),
98
+ np.rot90(x_aug[1], -1, axes=(2, 3)),
99
+ np.rot90(x_aug[2], -2, axes=(2, 3)),
100
+ np.rot90(x_aug[3], -3, axes=(2, 3)),
101
+ np.flip(x_aug[4].numpy(), axis=(1, 3)),
102
+ np.rot90(np.flip(x_aug[5].numpy(), axis=(1, 3)), -1, axes=(2, 3)),
103
+ np.rot90(np.flip(x_aug[6].numpy(), axis=(1, 3)), -2, axes=(2, 3)),
104
+ np.rot90(np.flip(x_aug[7].numpy(), axis=(1, 3)), -3, axes=(2, 3)),
101
105
  ]
102
106
  return np.mean(x_deaug, 0)
@@ -5,11 +5,10 @@ __all__ = [
5
5
  "denormalize",
6
6
  "normalize",
7
7
  "get_device",
8
- "check_array_validity",
9
8
  "check_axes_validity",
9
+ "add_axes",
10
10
  "check_tiling_validity",
11
11
  "cwd",
12
- "compile_model",
13
12
  "MetricTracker",
14
13
  ]
15
14
 
@@ -17,5 +16,5 @@ __all__ = [
17
16
  from .context import cwd
18
17
  from .metrics import MetricTracker
19
18
  from .normalization import denormalize, normalize
20
- from .torch_utils import compile_model, get_device
21
- from .validators import check_array_validity, check_axes_validity, check_tiling_validity
19
+ from .torch_utils import get_device
20
+ from .validators import add_axes, check_axes_validity, check_tiling_validity