careamics 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (111)
  1. careamics/__init__.py +17 -2
  2. careamics/careamist.py +4 -3
  3. careamics/cli/conf.py +1 -2
  4. careamics/cli/main.py +1 -2
  5. careamics/cli/utils.py +3 -3
  6. careamics/config/__init__.py +47 -25
  7. careamics/config/algorithms/__init__.py +15 -0
  8. careamics/config/algorithms/care_algorithm_model.py +38 -0
  9. careamics/config/algorithms/n2n_algorithm_model.py +30 -0
  10. careamics/config/algorithms/n2v_algorithm_model.py +29 -0
  11. careamics/config/algorithms/unet_algorithm_model.py +88 -0
  12. careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +14 -12
  13. careamics/config/architectures/__init__.py +1 -11
  14. careamics/config/architectures/architecture_model.py +3 -3
  15. careamics/config/architectures/lvae_model.py +6 -1
  16. careamics/config/architectures/unet_model.py +1 -0
  17. careamics/config/care_configuration.py +100 -0
  18. careamics/config/configuration.py +354 -0
  19. careamics/config/{configuration_factory.py → configuration_factories.py} +185 -57
  20. careamics/config/configuration_io.py +85 -0
  21. careamics/config/data/__init__.py +10 -0
  22. careamics/config/{data_model.py → data/data_model.py} +91 -186
  23. careamics/config/data/n2v_data_model.py +193 -0
  24. careamics/config/likelihood_model.py +1 -2
  25. careamics/config/n2n_configuration.py +101 -0
  26. careamics/config/n2v_configuration.py +266 -0
  27. careamics/config/nm_model.py +1 -2
  28. careamics/config/support/__init__.py +7 -7
  29. careamics/config/support/supported_algorithms.py +5 -4
  30. careamics/config/support/supported_architectures.py +0 -4
  31. careamics/config/transformations/__init__.py +10 -4
  32. careamics/config/transformations/transform_model.py +3 -3
  33. careamics/config/transformations/transform_unions.py +42 -0
  34. careamics/config/validators/__init__.py +12 -1
  35. careamics/config/validators/model_validators.py +84 -0
  36. careamics/config/validators/validator_utils.py +3 -3
  37. careamics/dataset/__init__.py +2 -2
  38. careamics/dataset/dataset_utils/__init__.py +3 -3
  39. careamics/dataset/dataset_utils/dataset_utils.py +4 -6
  40. careamics/dataset/dataset_utils/file_utils.py +9 -9
  41. careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
  42. careamics/dataset/in_memory_dataset.py +11 -12
  43. careamics/dataset/iterable_dataset.py +4 -4
  44. careamics/dataset/iterable_pred_dataset.py +2 -1
  45. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  46. careamics/dataset/patching/random_patching.py +11 -10
  47. careamics/dataset/patching/sequential_patching.py +26 -26
  48. careamics/dataset/patching/validate_patch_dimension.py +3 -3
  49. careamics/dataset/tiling/__init__.py +2 -2
  50. careamics/dataset/tiling/collate_tiles.py +3 -3
  51. careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
  52. careamics/dataset/tiling/tiled_patching.py +11 -10
  53. careamics/file_io/__init__.py +5 -5
  54. careamics/file_io/read/__init__.py +1 -1
  55. careamics/file_io/read/get_func.py +2 -2
  56. careamics/file_io/write/__init__.py +2 -2
  57. careamics/lightning/__init__.py +5 -5
  58. careamics/lightning/callbacks/__init__.py +1 -1
  59. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
  60. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
  61. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
  62. careamics/lightning/callbacks/progress_bar_callback.py +3 -3
  63. careamics/lightning/lightning_module.py +11 -7
  64. careamics/lightning/train_data_module.py +36 -45
  65. careamics/losses/__init__.py +3 -3
  66. careamics/lvae_training/calibration.py +64 -57
  67. careamics/lvae_training/dataset/lc_dataset.py +2 -1
  68. careamics/lvae_training/dataset/multich_dataset.py +2 -2
  69. careamics/lvae_training/dataset/types.py +1 -1
  70. careamics/lvae_training/eval_utils.py +123 -128
  71. careamics/model_io/__init__.py +1 -1
  72. careamics/model_io/bioimage/__init__.py +1 -1
  73. careamics/model_io/bioimage/_readme_factory.py +1 -1
  74. careamics/model_io/bioimage/model_description.py +17 -17
  75. careamics/model_io/bmz_io.py +6 -17
  76. careamics/model_io/model_io_utils.py +9 -9
  77. careamics/models/layers.py +16 -16
  78. careamics/models/lvae/likelihoods.py +2 -0
  79. careamics/models/lvae/lvae.py +13 -4
  80. careamics/models/lvae/noise_models.py +280 -217
  81. careamics/models/lvae/stochastic.py +1 -0
  82. careamics/models/model_factory.py +2 -15
  83. careamics/models/unet.py +8 -8
  84. careamics/prediction_utils/__init__.py +1 -1
  85. careamics/prediction_utils/prediction_outputs.py +15 -15
  86. careamics/prediction_utils/stitch_prediction.py +6 -6
  87. careamics/transforms/__init__.py +5 -5
  88. careamics/transforms/compose.py +13 -13
  89. careamics/transforms/n2v_manipulate.py +3 -3
  90. careamics/transforms/pixel_manipulation.py +9 -9
  91. careamics/transforms/xy_random_rotate90.py +4 -4
  92. careamics/utils/__init__.py +5 -5
  93. careamics/utils/context.py +2 -1
  94. careamics/utils/logging.py +11 -10
  95. careamics/utils/metrics.py +25 -0
  96. careamics/utils/plotting.py +78 -0
  97. careamics/utils/torch_utils.py +7 -7
  98. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/METADATA +13 -11
  99. careamics-0.0.7.dist-info/RECORD +178 -0
  100. careamics/config/architectures/custom_model.py +0 -162
  101. careamics/config/architectures/register_model.py +0 -103
  102. careamics/config/configuration_model.py +0 -603
  103. careamics/config/fcn_algorithm_model.py +0 -152
  104. careamics/config/references/__init__.py +0 -45
  105. careamics/config/references/algorithm_descriptions.py +0 -132
  106. careamics/config/references/references.py +0 -39
  107. careamics/config/transformations/transform_union.py +0 -20
  108. careamics-0.0.5.dist-info/RECORD +0 -171
  109. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/WHEEL +0 -0
  110. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/entry_points.txt +0 -0
  111. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,78 @@
1
+ """Plotting utilities."""
2
+
3
+ from typing import Optional
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import torch
8
+ from numpy.typing import NDArray
9
+
10
+ from careamics.models.lvae.noise_models import GaussianMixtureNoiseModel
11
+
12
+
13
+ def plot_noise_model_probability_distribution(
14
+ noise_model: GaussianMixtureNoiseModel,
15
+ signalBinIndex: int,
16
+ histogram: NDArray,
17
+ channel: Optional[str] = None,
18
+ number_of_bins: int = 100,
19
+ ) -> None:
20
+ """Plot probability distribution P(x|s) for a certain ground truth signal.
21
+
22
+ Predictions from both Histogram and GMM-based
23
+ Noise models are displayed for comparison.
24
+
25
+ Parameters
26
+ ----------
27
+ noise_model : GaussianMixtureNoiseModel
28
+ Trained GaussianMixtureNoiseModel.
29
+ signalBinIndex : int
30
+ Index of signal bin. Values go from 0 to number of bins (`n_bin`).
31
+ histogram : NDArray
32
+ Histogram based noise model.
33
+ channel : Optional[str], optional
34
+ Channel name used for plotting. Default is None.
35
+ number_of_bins : int, optional
36
+ Number of bins in the resulting histogram. Default is 100.
37
+ """
38
+ min_signal = noise_model.min_signal.item()
39
+ max_signal = noise_model.max_signal.item()
40
+ bin_size = (max_signal - min_signal) / number_of_bins
41
+
42
+ query_signal_normalized = signalBinIndex / number_of_bins
43
+ query_signal = query_signal_normalized * (max_signal - min_signal) + min_signal
44
+ query_signal += bin_size / 2
45
+ query_signal = torch.tensor(query_signal)
46
+
47
+ query_observations = torch.arange(min_signal, max_signal, bin_size)
48
+ query_observations += bin_size / 2
49
+
50
+ likelihoods = noise_model.likelihood(
51
+ observations=query_observations, signals=query_signal
52
+ ).numpy()
53
+
54
+ plt.figure(figsize=(12, 5))
55
+ if channel:
56
+ plt.suptitle(f"Noise model for channel {channel}")
57
+ else:
58
+ plt.suptitle("Noise model")
59
+
60
+ plt.subplot(1, 2, 1)
61
+ plt.xlabel("Observation Bin")
62
+ plt.ylabel("Signal Bin")
63
+ plt.imshow(histogram**0.25, cmap="gray")
64
+ plt.axhline(y=signalBinIndex + 0.5, linewidth=5, color="blue", alpha=0.5)
65
+
66
+ plt.subplot(1, 2, 2)
67
+ plt.plot(
68
+ query_observations,
69
+ likelihoods,
70
+ label="GMM : " + " signal = " + str(np.round(query_signal, 2)),
71
+ marker=".",
72
+ color="red",
73
+ linewidth=2,
74
+ )
75
+ plt.xlabel("Observations (x) for signal s = " + str(query_signal))
76
+ plt.ylabel("Probability Density")
77
+ plt.title("Probability Distribution P(x|s) at signal =" + str(query_signal))
78
+ plt.legend()
@@ -5,7 +5,7 @@ These functions are used to control certain aspects and behaviours of PyTorch.
5
5
  """
6
6
 
7
7
  import inspect
8
- from typing import Dict, Union
8
+ from typing import Union
9
9
 
10
10
  import torch
11
11
 
@@ -27,12 +27,12 @@ def filter_parameters(
27
27
  ----------
28
28
  func : type
29
29
  Class object.
30
- user_params : Dict
30
+ user_params : dict
31
31
  User provided parameters.
32
32
 
33
33
  Returns
34
34
  -------
35
- Dict
35
+ dict
36
36
  Parameters matching `func`'s signature.
37
37
  """
38
38
  # Get the list of all default parameters
@@ -64,13 +64,13 @@ def get_optimizer(name: str) -> torch.optim.Optimizer:
64
64
  return getattr(torch.optim, name)
65
65
 
66
66
 
67
- def get_optimizers() -> Dict[str, str]:
67
+ def get_optimizers() -> dict[str, str]:
68
68
  """
69
69
  Return the list of all optimizers available in torch.optim.
70
70
 
71
71
  Returns
72
72
  -------
73
- Dict
73
+ dict
74
74
  Optimizers available in torch.optim.
75
75
  """
76
76
  optims = {}
@@ -106,13 +106,13 @@ def get_scheduler(
106
106
  return getattr(torch.optim.lr_scheduler, name)
107
107
 
108
108
 
109
- def get_schedulers() -> Dict[str, str]:
109
+ def get_schedulers() -> dict[str, str]:
110
110
  """
111
111
  Return the list of all schedulers available in torch.optim.lr_scheduler.
112
112
 
113
113
  Returns
114
114
  -------
115
- Dict
115
+ dict
116
116
  Schedulers available in torch.optim.lr_scheduler.
117
117
  """
118
118
  schedulers = {}
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: careamics
3
- Version: 0.0.5
3
+ Version: 0.0.7
4
4
  Summary: Toolbox for running N2V and friends.
5
5
  Project-URL: homepage, https://careamics.github.io/
6
6
  Project-URL: repository, https://github.com/CAREamics/careamics
@@ -17,17 +17,19 @@ Classifier: Programming Language :: Python :: 3.12
17
17
  Classifier: Typing :: Typed
18
18
  Requires-Python: >=3.9
19
19
  Requires-Dist: bioimageio-core==0.7
20
+ Requires-Dist: matplotlib<=3.10.0
20
21
  Requires-Dist: numpy<2.0.0
21
- Requires-Dist: pillow<=10.3.0
22
- Requires-Dist: psutil<=6.1
23
- Requires-Dist: pydantic<2.9,>=2.5
24
- Requires-Dist: pytorch-lightning<=2.4,>=2.2
22
+ Requires-Dist: pillow<=11.1.0
23
+ Requires-Dist: psutil<=6.1.1
24
+ Requires-Dist: pydantic<2.11,>=2.5
25
+ Requires-Dist: pytorch-lightning<=2.5.0.post0,>=2.2
25
26
  Requires-Dist: pyyaml!=6.0.0,<=6.0.2
26
- Requires-Dist: scikit-image<=0.23.2
27
- Requires-Dist: tifffile<=2024.8.30
28
- Requires-Dist: torch<=2.5,>=2.0
29
- Requires-Dist: torchvision<=0.20
30
- Requires-Dist: typer==0.12.3
27
+ Requires-Dist: scikit-image<=0.25.1
28
+ Requires-Dist: tifffile<=2025.1.10
29
+ Requires-Dist: torch<=2.6.0,>=2.0
30
+ Requires-Dist: torchvision<=0.20.1
31
+ Requires-Dist: torchvision<=0.21.0
32
+ Requires-Dist: typer<=0.15.1,>=0.12.3
31
33
  Requires-Dist: zarr<3.0.0
32
34
  Provides-Extra: dev
33
35
  Requires-Dist: onnx; extra == 'dev'
@@ -40,7 +42,7 @@ Requires-Dist: careamics-portfolio; extra == 'examples'
40
42
  Requires-Dist: jupyter; extra == 'examples'
41
43
  Requires-Dist: matplotlib; extra == 'examples'
42
44
  Provides-Extra: tensorboard
43
- Requires-Dist: protobuf==3.20.3; extra == 'tensorboard'
45
+ Requires-Dist: protobuf==5.29.1; extra == 'tensorboard'
44
46
  Requires-Dist: tensorboard; extra == 'tensorboard'
45
47
  Provides-Extra: wandb
46
48
  Requires-Dist: wandb; extra == 'wandb'
@@ -0,0 +1,178 @@
1
+ careamics/__init__.py,sha256=WF2JpQmC-MmuSB0L81XRo67NwaN_0qjyywcpRlbVJVE,569
2
+ careamics/careamist.py,sha256=rakSaDSGRR0Pr0o1s8ejwIrjWknOQs4DgljukqfyWu0,37635
3
+ careamics/conftest.py,sha256=Od4WcaaP0UP-XUMrFr_oo4e6c2hi_RvNbuaRTopwlmI,911
4
+ careamics/py.typed,sha256=esB4cHc6c07uVkGtqf8at7ttEnprwRxwk8obY8Qumq4,187
5
+ careamics/cli/__init__.py,sha256=LbM9bVtU1dy-khmdiIDXwvKy2v8wPBCEUuWqV_8rosA,106
6
+ careamics/cli/conf.py,sha256=oixGRZNylW-NTM_rkDtQkSRw8KUYwtmUC_hK5BEeLnA,13074
7
+ careamics/cli/main.py,sha256=S4B3c1ZN-OQK0l2_W42CaW0KmF_Pe_y4pKgn_UOuyDg,6564
8
+ careamics/cli/utils.py,sha256=q_dmG7lxg_FT62qX9fPilIWL1M8ibhLnnhUKqa4knPI,660
9
+ careamics/config/__init__.py,sha256=c5FXtcQrtROHdLRl4cHpKo6V4_E4yr6HxYgNwqH9CHg,1917
10
+ careamics/config/callback_model.py,sha256=EeYHqpMIPQwyNxLRzzX32Uncl5mZuB1bJO76RHpNymg,4555
11
+ careamics/config/care_configuration.py,sha256=hSfNJ-dooHm4ujG6Q3Hawr8zDeYy1HNiebtM7gxsh7s,2714
12
+ careamics/config/configuration.py,sha256=KmLeXHkFhQTrcru1erhfVf3tHvQdoi12ls_u254rtDw,11114
13
+ careamics/config/configuration_factories.py,sha256=9civH9r1yfXcYeXJS61ft3Wsn8PoODNeqzGU45CTFCs,35460
14
+ careamics/config/configuration_io.py,sha256=ks9R8lRCBY_m0sdy1k4ZWFPKEFSp7K9X47jCG4d0FY4,2353
15
+ careamics/config/inference_model.py,sha256=UE_-ZmCX6LFCbDBOwyGnvuAboF_JNX2m2LcF0WiwgCI,6961
16
+ careamics/config/likelihood_model.py,sha256=VorUtc0_-xIWNxwVrd1kBba-003ICdVMtxpcDCxH4Io,2259
17
+ careamics/config/loss_model.py,sha256=yYcUBS90Qyon1MxeaHiVP3dJHPJFC0GUvWKGcAb3IHk,2036
18
+ careamics/config/n2n_configuration.py,sha256=QoHnEak0LiG6rBIXnmx61Chkz-Q0jr2Qt9I0NnmgZo4,2667
19
+ careamics/config/n2v_configuration.py,sha256=5MagUFaVZ-9VGIiaKD-4RRkFpJen2imMoR3Lu07BD14,8489
20
+ careamics/config/nm_model.py,sha256=5dAhDBLa4WPfKaNEK6ATNsSUwtlH8u8gYweEA4gZP6g,4758
21
+ careamics/config/optimizer_models.py,sha256=OWpTydRBBR8wt_af1mZHNNwvL_RtnRFopAOdgjzLo30,5750
22
+ careamics/config/tile_information.py,sha256=c-_xrVPOgcnjiEzQ-9A_GhNPamObkMANbeHaRP29R-4,2059
23
+ careamics/config/training_model.py,sha256=67_ipo_-LxhT4-WqAs40Sg8PjU--my43Qn3BhjvlXxM,3212
24
+ careamics/config/algorithms/__init__.py,sha256=on5D6zBO9Lu-Tf5To8xpF6owIFqdN7RSmZdpyXDOaNw,404
25
+ careamics/config/algorithms/care_algorithm_model.py,sha256=-AEfzrA4HugYwMdDHqcc_i1H9kycXKS0YyRwT_MCCPo,951
26
+ careamics/config/algorithms/n2n_algorithm_model.py,sha256=aWG6-YYB9T2e-QirxC-YijosO3QNA0rJRPIQZrorSi0,799
27
+ careamics/config/algorithms/n2v_algorithm_model.py,sha256=b1M1ab9D8rhCG7RmhmaDi5rjHJjaQXebTdUEewJBnNg,709
28
+ careamics/config/algorithms/unet_algorithm_model.py,sha256=OaBFVlhsb9YhF3f2x1ImazvfnZ4_DPWvYWihRwurkeg,2587
29
+ careamics/config/algorithms/vae_algorithm_model.py,sha256=1ZrShUGHA7zbjSiCQwhUSv7l7Hr7MbpcLOw8ucM27p8,4680
30
+ careamics/config/architectures/__init__.py,sha256=lYUz56R7LDqVQWQDLkLgJL8VtOyxc_ege1r4bXGEBqA,220
31
+ careamics/config/architectures/architecture_model.py,sha256=gXn4gdLrQP3bmTQxIhzkEHYlodPaIp6LI-kwZl23W-Y,911
32
+ careamics/config/architectures/lvae_model.py,sha256=lwOiJYNUqPrZFl9SpPLYon77EiRbe2eI7pmpx45rO78,7606
33
+ careamics/config/architectures/unet_model.py,sha256=HJJWf-wuYTv8KXMwikdjeB8htg1QOt0IyQbMuBt1LUI,3556
34
+ careamics/config/data/__init__.py,sha256=ijFNSRKkKVF8fw6ym8kXq_wNhdwkWXxuNj2XKtD3KjE,218
35
+ careamics/config/data/data_model.py,sha256=Z-sy6bJ_JYT5fzbVwLOH1PnjRkVY0YAmNG5OnQAd-SQ,12635
36
+ careamics/config/data/n2v_data_model.py,sha256=-n5cncmcrd4-KUSmEk8rXAONMC7sfbYfGmshVNuySCU,5915
37
+ careamics/config/support/__init__.py,sha256=ktWvxbTkRXnQPS_N84l9E2B5kTZVdd64SIjsJIQKB-k,1041
38
+ careamics/config/support/supported_activations.py,sha256=CqOWoziIK5jZZXJO7G7cGg3TTid1POqv8FXqxjXxyME,535
39
+ careamics/config/support/supported_algorithms.py,sha256=w6YzcIqGZ_bS85Tw1s7TEltBDXLt4SzgN3Tc6s19dGU,946
40
+ careamics/config/support/supported_architectures.py,sha256=pOxvHOAIUkc7HCO0IIg4K22h-Ti5ErtcIkGOjN-zh1s,340
41
+ careamics/config/support/supported_data.py,sha256=T_mDiWLFMVji_EpjBABUObAJcnv-XBnqp9XUZP37Tdk,2902
42
+ careamics/config/support/supported_loggers.py,sha256=ubSOkGoYabGbm_jmyc1R3eFcvcP-sHmuyiBi_d3_wLg,197
43
+ careamics/config/support/supported_losses.py,sha256=2x5sZuxRbWJzodoL35I1mMYUUDMzk8UFiFdbyPwbJ4E,583
44
+ careamics/config/support/supported_optimizers.py,sha256=_2XmwzYENB6xpTedyWHUdWuGcDzdlfEAJjzm_qI3yRM,1392
45
+ careamics/config/support/supported_pixel_manipulations.py,sha256=rFiktUlvoFU7s1NAKEMqsXOzLw5eaw9GtCKUznvq6xc,432
46
+ careamics/config/support/supported_struct_axis.py,sha256=alZMA5Y-BpDymLPUEd1zqVY0xMkgl9Rv1d4ujED6sco,424
47
+ careamics/config/support/supported_transforms.py,sha256=ODvmoTywvJWG_5-SJJZu-X1FNtKGhkNWQc-t26IFZWI,311
48
+ careamics/config/transformations/__init__.py,sha256=jMTUX15n8ZF4Nc9gQp-qbXfhj-iEsEa65lzNbHwzyfY,631
49
+ careamics/config/transformations/n2v_manipulate_model.py,sha256=Mdxc4J3vxe_dM2CIhmTwwGOIirQvrQXLoa2vRsTzoYI,1855
50
+ careamics/config/transformations/normalize_model.py,sha256=1Rkk6IkF-7ytGU6HSzP-TpOi4RRWiQJ6fOd8zammXcg,1936
51
+ careamics/config/transformations/transform_model.py,sha256=6UVbXnxm-LLZOQQ-ZBwWwgmS_99DiBuERLfMxrta3-8,990
52
+ careamics/config/transformations/transform_unions.py,sha256=uqlI8Nm827bKfMbDQLVhKFtT9e7TJ_zIYDBdHlOuQ1I,1137
53
+ careamics/config/transformations/xy_flip_model.py,sha256=zU-uZ1b1zNZWckbho3onN-B7BHKhN7jbgbNZyRQhv2s,1025
54
+ careamics/config/transformations/xy_random_rotate90_model.py,sha256=6sYKmtCLvz0SV1qZgBSHUTH-CUjwvHnohq1HyPntbyE,894
55
+ careamics/config/validators/__init__.py,sha256=zRrIse0O3ImwG97NphEupFVm3Ib9nEJhpNNrKGyDTps,423
56
+ careamics/config/validators/model_validators.py,sha256=9OCdlf7rmndtTpmQ8COLaEjURYYmszic_RjY9mzS-k4,1941
57
+ careamics/config/validators/validator_utils.py,sha256=NVkEOr5AQK4JXWNtmgeQgAaJOyieJNb5PHCjlcqNeew,2611
58
+ careamics/dataset/__init__.py,sha256=31vop67zbtGesENEIig-LLw1q2lCydMFc_YWgfK2Yt4,547
59
+ careamics/dataset/in_memory_dataset.py,sha256=MV_Vf4siIP-g7VKhxN4rU7MZXpaHKvfwr8ZXqk44Qhs,9958
60
+ careamics/dataset/in_memory_pred_dataset.py,sha256=VvwW5D8TjgO_kR8eZinP-9qepSiI6ZsUN7FZ0Rvc8Bs,2161
61
+ careamics/dataset/in_memory_tiled_pred_dataset.py,sha256=DANmlnlV1ysXKdwGvmJoOYKcjlgoMhnSGSDRpeK79ZA,3552
62
+ careamics/dataset/iterable_dataset.py,sha256=pqtm-AWhDbuZTnXf0roAHWVxGPRTekAzVcJHLzaSyFU,9797
63
+ careamics/dataset/iterable_pred_dataset.py,sha256=jee4b8bZOyvSS5qfIsb6Jkk1EV_MKEU2SyZ0m7p0p9k,3767
64
+ careamics/dataset/iterable_tiled_pred_dataset.py,sha256=2j_kLMB6DfSKXPszZPYgsB08TVgcf1V5HY_kZVozrFM,4560
65
+ careamics/dataset/zarr_dataset.py,sha256=lojnK5bhiF1vyjuPtWXBrZ9sy5fT_rBvZJbbbnE-H_I,5665
66
+ careamics/dataset/dataset_utils/__init__.py,sha256=MJ3xriL6R4ZtmzbvLsASUWLb85Hk5AdeRaYnHpNELJQ,507
67
+ careamics/dataset/dataset_utils/dataset_utils.py,sha256=X83DzaOWmHdl4eOPac2IQJH3bPA43RVq0hPrFrzvIXQ,2630
68
+ careamics/dataset/dataset_utils/file_utils.py,sha256=ru6AtQ9LCmo6raN1-GnJEN4UyP1PbmSdR9MEys3CuHo,4094
69
+ careamics/dataset/dataset_utils/iterate_over_files.py,sha256=Jun35Qn9XevHOb_DixYBMHDAOykLmiwciA5Q2MzSUK8,2912
70
+ careamics/dataset/dataset_utils/running_stats.py,sha256=kWorioMH4S5uZj2cvUpjHB6cIUhMFa1XXwDQrrKIWdI,5752
71
+ careamics/dataset/patching/__init__.py,sha256=7-s12oUAZNlMOwSkxSwbD7vojQINWYFzn_4qIJ87WBg,37
72
+ careamics/dataset/patching/patching.py,sha256=deAxY34Iz-mguBlHQ-5EO4vRhPpR9I3LQ9onV1K_KqA,8858
73
+ careamics/dataset/patching/random_patching.py,sha256=gm1jxye9yvHbdijLzCtDSzRU_9j110GRLMnJaUwLAHQ,6487
74
+ careamics/dataset/patching/sequential_patching.py,sha256=4F5E1Ta0M5kFXGwI2-QXRxeOx0CyUwbFaB5awkMCN_Q,5890
75
+ careamics/dataset/patching/validate_patch_dimension.py,sha256=mC2bZWBpU44NEvXxEfR7ULUKwWPuZPjmBWpHYJxNDWc,2121
76
+ careamics/dataset/tiling/__init__.py,sha256=aW_AMB9rzm0VmooUpjcyqv6sQP69RlPQMEdP2sVjdz8,190
77
+ careamics/dataset/tiling/collate_tiles.py,sha256=XK0BsDQE7XwIwmOoCHJIpVC3kqjSN6nDhrJ4POVeHS8,965
78
+ careamics/dataset/tiling/lvae_tiled_patching.py,sha256=LYEEdjKuKaxIGFtOkhfpsE7hruBnIsD5HcW9aVH6WHI,13019
79
+ careamics/dataset/tiling/tiled_patching.py,sha256=6vxsqlccUqIl4Ys92JWIPs0Kn95VzaHoAYMSGcp2dh8,5956
80
+ careamics/file_io/__init__.py,sha256=vgMI77X820VOWywAEW5W20FXfmbqBzx4V63D3V3_HhI,334
81
+ careamics/file_io/read/__init__.py,sha256=wf8O_o80ghrlWQ-RGEuSqcc2LU55P1B-oxTacDToygo,259
82
+ careamics/file_io/read/get_func.py,sha256=O_pdymjh2mc-JZ1je3ZnPAcsHc7Je3a005AMgAa0xuw,1388
83
+ careamics/file_io/read/tiff.py,sha256=UMofW33rvByK9B1zYGhSrWAiAA3uQUV3OVK7cq9d0gQ,1359
84
+ careamics/file_io/read/zarr.py,sha256=2jzREAnJDQSv0qmsL-v00BxmiZ_sp0ijq667LZSQ_hY,1685
85
+ careamics/file_io/write/__init__.py,sha256=CUt33cRjG9hm18L9a7XqaUKWQ_3xiuQ9ztz4Ab7RYG0,283
86
+ careamics/file_io/write/get_func.py,sha256=hyGHe1RX-lfa9QFAnwRCz_gS0NRiRnXEtg4Bdeh2Esc,1627
87
+ careamics/file_io/write/tiff.py,sha256=tBGIgl-I1sMyBivgx-dOTBykXBODkgwPH8MT3_4KAE8,1050
88
+ careamics/lightning/__init__.py,sha256=ATCVAGnX08Ik4TxbIv0-cXb52UinR42JgvZh_GIMSpc,588
89
+ careamics/lightning/lightning_module.py,sha256=3nzRcjv7dlWmWjoWmGaBV7D6d698xMp3XO7fpIrZhwA,22630
90
+ careamics/lightning/predict_data_module.py,sha256=JNwujK6QwObSx6P25ghpGl2f2gGT3KVgYMTlonZzH20,12745
91
+ careamics/lightning/train_data_module.py,sha256=LeTyjNtAJw8nNiw2k6Ifuw0fAgppyZyRxyEGbDq30Fo,28309
92
+ careamics/lightning/callbacks/__init__.py,sha256=eA5ltzYNzuO0uMEr1jG4wP01b0s29s5I03WGJ290qkw,312
93
+ careamics/lightning/callbacks/hyperparameters_callback.py,sha256=u45knOZHwoVHz6yYfrnERQuozT_SfZ1OrKP0QjeU4EM,1495
94
+ careamics/lightning/callbacks/progress_bar_callback.py,sha256=w-j_nk2ysyc4THKfwWbpkiKGeqNUpLGtm-8dYBgla2c,2443
95
+ careamics/lightning/callbacks/prediction_writer_callback/__init__.py,sha256=ZVf3vaSU_NjSjrKbI24H0kK9WAiP9oKXfhP670EaWMo,548
96
+ careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py,sha256=i4vGGiVLslafi-5iuvkAKzBgZ0BpwTTxSTo31oViFz4,1480
97
+ careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py,sha256=8HHUSKcG7G0FSCVPnpGQHLfpara5mnKAwsiiyWp2wzo,8210
98
+ careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py,sha256=lxsLjLskRpYnzdyWCdOICUJxF9YzuUi1RH0LJnOCVgo,12594
99
+ careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py,sha256=F1IpbNNgkv5eK8Xpqp7wqv2lsqEdP1wMRlBL7RBn93U,7114
100
+ careamics/losses/__init__.py,sha256=nSWbkBcFhkyUkIT2wVcULqpieyY2Oro39NXZTtfQpXo,351
101
+ careamics/losses/loss_factory.py,sha256=oPacrkwiabsmiW_r--IxX-XPRbzezZUvOuWKbUw5LiI,1518
102
+ careamics/losses/fcn/__init__.py,sha256=kf92MKFGHr6upiztZVgWwtGPf734DZyub92Rn8uEq8o,18
103
+ careamics/losses/fcn/losses.py,sha256=NdOz29hzJ7D26p13q-g0NWoYwNauIWrP2xWww6YPbB8,2360
104
+ careamics/losses/lvae/__init__.py,sha256=0FNtMLHrOMfagtWkaBdz1NTjyf2y0QLgysxJv5jq5uw,19
105
+ careamics/losses/lvae/loss_utils.py,sha256=QxzA2N1TglR4H0X0uyTWWytDagE1lA9IB_TK1lms3ao,2720
106
+ careamics/losses/lvae/losses.py,sha256=wHT1dx04BZ_OI-_S7cFQ5hFmMetm6FSnuZfwZBBtIpY,17977
107
+ careamics/lvae_training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
108
+ careamics/lvae_training/calibration.py,sha256=xHbiLcY2csYos3s7rRSqp7P7G-9wzULcSo1JfVzfIjE,7239
109
+ careamics/lvae_training/eval_utils.py,sha256=7N1thslU4IU1lM1tGg3-wa8AFf5_R2lOSQ7ZZ91AUII,30030
110
+ careamics/lvae_training/get_config.py,sha256=dwVfaQS7nzjQss0E1gGLUpQpjPcOWwLgIhbu3Z0I1rg,3068
111
+ careamics/lvae_training/lightning_module.py,sha256=ryr7iHqCMzCl5esi6_gEcnKFDQkMrw0EXK9Zfgv1Nek,27186
112
+ careamics/lvae_training/metrics.py,sha256=KTDAKhe3vh-YxzGibjtkIG2nnUyujbnwqX4xGwaRXwE,6718
113
+ careamics/lvae_training/train_lvae.py,sha256=lJEBlBGdISVkZBcEnPNRYgJ7VbapYzZHRaFOrZ0xYGE,11080
114
+ careamics/lvae_training/train_utils.py,sha256=e-d4QsF-li8MmAPkAmB1daHpkuU16nBTnQFZYqpTjn4,3567
115
+ careamics/lvae_training/dataset/__init__.py,sha256=dvdHHaRA9ZfOt_uOnXkYyra2_b0Wsxs8qmrze6zxJAE,377
116
+ careamics/lvae_training/dataset/config.py,sha256=hGIggj5uOZrFBK54o9vii0sG5WGhF_E32URKIIzQMec,4342
117
+ careamics/lvae_training/dataset/lc_dataset.py,sha256=r4PffRXzuTJ0tLWei4B3wq6f1Q34raaZQzZ0IQXi8OI,10762
118
+ careamics/lvae_training/dataset/multich_dataset.py,sha256=5yMC6bgEIYHBsjFj5gXlc68xJQz8A05TYbYfOo-TdUQ,41672
119
+ careamics/lvae_training/dataset/multifile_dataset.py,sha256=hJBs6iBrf_FcyUYzg8rDjvKEICHxDYyXVOj-5L0F6FE,10273
120
+ careamics/lvae_training/dataset/types.py,sha256=SQ99hV9R3iwrRLJs-aRkL3OlmrWWkCrca2JqkntoWZs,633
121
+ careamics/lvae_training/dataset/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
122
+ careamics/lvae_training/dataset/utils/data_utils.py,sha256=8PvRPqSbYHPCl87cycZHxXIFOT_EoBV-8XCt3ZLh36s,3125
123
+ careamics/lvae_training/dataset/utils/empty_patch_fetcher.py,sha256=OFjeqhZ6vFULJsF5tnoByhEhE8aLHujFToU_yyqMCP4,2266
124
+ careamics/lvae_training/dataset/utils/index_manager.py,sha256=Gt1I-7lBaQDBgqguOmofAFDdAsJQfz7ktvq4_I80F9c,10084
125
+ careamics/lvae_training/dataset/utils/index_switcher.py,sha256=ZoMi8LsaIkm8MFqIFaxN4oQGyzCcwOlCom8SYNus15E,6716
126
+ careamics/model_io/__init__.py,sha256=khMIkk107LL5JGze0OVfl5Lfi14R3_e4W21tW0iJ1kE,155
127
+ careamics/model_io/bmz_io.py,sha256=0fL0cz_80nsu_Jgk-ImXKTNbn_WSdrXB9flvLm2g3Is,7792
128
+ careamics/model_io/model_io_utils.py,sha256=HCSoNvaOo55kI7teZleN57riqNK2fLq77zMadBKXCyc,2777
129
+ careamics/model_io/bioimage/__init__.py,sha256=dKm-UluVwa_7EHQB0ukx-Qk8JVWtuT-OUinY8hE9EIw,298
130
+ careamics/model_io/bioimage/_readme_factory.py,sha256=sRfHbuwhfYmpwsG0keXFCSK3qCS4VI4GQhX6OhSKLjY,3440
131
+ careamics/model_io/bioimage/bioimage_utils.py,sha256=YVr75SDfafOiuYGonPbcsO-xVS0wI1WkmWQQZx6DXYQ,1246
132
+ careamics/model_io/bioimage/cover_factory.py,sha256=8URrpEfJvdHBJeSrh5H2IQHSUybsTyAOR3_A-YYAAlw,4583
133
+ careamics/model_io/bioimage/model_description.py,sha256=HwjWT77nIt6LkLvIBRxwgPOsL0ffrCMtDUjim39qgqE,10071
134
+ careamics/models/__init__.py,sha256=Xui2BLJd1I2r_E3Sj24fJALFTi2FGtfNscUWj_0c9Hk,93
135
+ careamics/models/activation.py,sha256=nu3sDgd7Lsyw8rvmUwxNN-7SM09cEMfxZ9DRDzdSKns,1049
136
+ careamics/models/layers.py,sha256=tpsxbolRWYycZGxS3hKlDRtMf6HpNdZs98uwx5K8lls,13757
137
+ careamics/models/model_factory.py,sha256=GWbouERvEfHj_BrpKHYgrPAj8dpdoh1R-X--Jt09N1Q,1370
138
+ careamics/models/unet.py,sha256=9m8GxsTXX9c0mC-eVe2ZXQn2afose1CG6Z8vIhELb7I,14308
139
+ careamics/models/lvae/__init__.py,sha256=6dT6uqgT__V08EjoTGxXguTbTkySZmByS9J2Bj6WWLM,53
140
+ careamics/models/lvae/layers.py,sha256=UPxZiZjgnBnPs_wdSzcP-_s17MEw4P1CoIZzn4OdUA0,57944
141
+ careamics/models/lvae/likelihoods.py,sha256=qRYRewQv6PqzJO-7nDnFkK86-R8dI4HSp1_ilRSc2I4,12233
142
+ careamics/models/lvae/lvae.py,sha256=Jlw3mxVCxMDtjMvBWI9C9javHHyrngm8RfTYPsYhbI4,34767
143
+ careamics/models/lvae/noise_models.py,sha256=lpSygXsJmD_erP0V72u9i5CX51wpopLNCH_YjmEL29s,24095
144
+ careamics/models/lvae/stochastic.py,sha256=wiTrLBSYwOvsF1araKxUHy1CHp1mdH9bazctVo0NchA,16628
145
+ careamics/models/lvae/utils.py,sha256=EE3paHu3vhCaqfOrGypzUsImZJO94uBhx8q6kZ-R36o,11516
146
+ careamics/prediction_utils/__init__.py,sha256=k5hsPGY8FOkwIT0fQgrUz7fVCH2NlwuOdZiISdXjEWg,270
147
+ careamics/prediction_utils/lvae_prediction.py,sha256=ZwPFCSeUGsULIMoMQWRYKHfLFaDm7UKyGaUMVfSUqfs,6210
148
+ careamics/prediction_utils/lvae_tiling_manager.py,sha256=SI-JaJvLrKWBSHdm-FjcqWdbhlcflTRiKxYF7CSGzvA,13736
149
+ careamics/prediction_utils/prediction_outputs.py,sha256=fw-bJ2szWJD7BgZlECmxy5sgeXGFJl4T8cRNzLR1aUQ,4069
150
+ careamics/prediction_utils/stitch_prediction.py,sha256=8YRW2rea-is5tYI0Q1bw3bpX7VMFmbpxSP_y6x9Yfug,3893
151
+ careamics/transforms/__init__.py,sha256=WtgpSFL_CJwpa47XzqS7bVXHPJ4qW0TamEymy_kgWQQ,483
152
+ careamics/transforms/compose.py,sha256=ZSVwKg3LT2PrwtSBKtkb6AHsVSSlZIdv9wTYl4To1s4,5682
153
+ careamics/transforms/n2v_manipulate.py,sha256=t9rtMbYV6P1IVp4yzuJfq5-giWyfGrxL8ZhzP29Pp8k,5686
154
+ careamics/transforms/normalize.py,sha256=fxs813ydCWrIzrxFzkbk1gW8OGSr0esQSrNUFSJuGL0,7715
155
+ careamics/transforms/pixel_manipulation.py,sha256=38NsxY8ARvz7GSNDKx5g67Hv5qciBzadiELAg5OcUSU,13355
156
+ careamics/transforms/struct_mask_parameters.py,sha256=jE29Li9sx3olaRnqYfJsSlKi2t0WQzJmCm9aCbIQEsA,421
157
+ careamics/transforms/transform.py,sha256=cEqc4ci8na70i-HIGYC7udRfVa8D_8OjdRVrr3txLvQ,464
158
+ careamics/transforms/tta.py,sha256=78S7Df9rLHmEVSQSI1qDcRrRJGauyG3oaIrXkckCkmw,2335
159
+ careamics/transforms/xy_flip.py,sha256=64BDo8bmAEwO1TNhbIYcUJPzzVmY5ZyNaSNmmGLkn0U,3842
160
+ careamics/transforms/xy_random_rotate90.py,sha256=Kin42yaV4Z8lOwC9nN8gxK73rgnJ2MhCoHHPQmlSgvc,3185
161
+ careamics/utils/__init__.py,sha256=mLwBQ7wTL2EwDwL3NcX53EHPNklojU45Jcc728y4EWQ,402
162
+ careamics/utils/autocorrelation.py,sha256=M_WYzrEOQngc5iSXWar4S3-EOnK6DfYHPC2vVMeu_Bs,945
163
+ careamics/utils/base_enum.py,sha256=bz1D8mDx5V5hdnJ3WAzJXWHJTbgwAky5FprUt9F5cMA,1387
164
+ careamics/utils/context.py,sha256=SoTZfzG6fO4SDOGHOTL2Xlm1n1CSgb9B57GVhrEkFls,1436
165
+ careamics/utils/lightning_utils.py,sha256=DMMmqx-AlNtddBCqm8b_W3h09qUetz32OMPhdDieFwg,1769
166
+ careamics/utils/logging.py,sha256=5U4VsQ4m4OajtirLH6qUjrM1CAc-oXeCsd6JyROjkWE,10337
167
+ careamics/utils/metrics.py,sha256=i9TQNzVF6lUL9c6OwRZFFDhelZfinkEDpWSCKeduscc,10853
168
+ careamics/utils/path_utils.py,sha256=8AugiG5DOmzgSnTCJI8vypXaPE0XhnR-9pzeiFUZ-0I,554
169
+ careamics/utils/plotting.py,sha256=cea1GQB932j2UA3IQZnh-0EenQdnjzPOFoGoFKJ4how,2518
170
+ careamics/utils/ram.py,sha256=tksyn8dVX_iJXmrDZDGub32hFZWIaNxnMheO5G1p43I,244
171
+ careamics/utils/receptive_field.py,sha256=Y2h4c8S6glX3qcx5KHDmO17Kkuyey9voxfoXyqcAfiM,3296
172
+ careamics/utils/serializers.py,sha256=mILUhz75IMpGKnEzcYu9hlOPG8YIiIW09fk6eZM7Y8k,1427
173
+ careamics/utils/torch_utils.py,sha256=_Cf3HdlIRl5hxfpUg9aofCSlcW7GSsIJxsbSORXko0U,3010
174
+ careamics-0.0.7.dist-info/METADATA,sha256=K3w_i8E8INNeZwEQdVtGHDFS_C0031Bfqzn23e8SSO4,3967
175
+ careamics-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
176
+ careamics-0.0.7.dist-info/entry_points.txt,sha256=2fSNVXJWDJgFLATVj7MkjFNvpl53amG8tUzC3jf7G1s,53
177
+ careamics-0.0.7.dist-info/licenses/LICENSE,sha256=6zdNW-k_xHRKYWUf9tDI_ZplUciFHyj0g16DYuZ2udw,1509
178
+ careamics-0.0.7.dist-info/RECORD,,
@@ -1,162 +0,0 @@
1
- """Custom architecture Pydantic model."""
2
-
3
- from __future__ import annotations
4
-
5
- import inspect
6
- from pprint import pformat
7
- from typing import Any, Literal
8
-
9
- from pydantic import ConfigDict, field_validator, model_validator
10
- from torch.nn import Module
11
- from typing_extensions import Self
12
-
13
- from .architecture_model import ArchitectureModel
14
- from .register_model import get_custom_model
15
-
16
-
17
- class CustomModel(ArchitectureModel):
18
- """Custom model configuration.
19
-
20
- This Pydantic model allows storing parameters for a custom model. In order for the
21
- model to be valid, the specific model needs to be registered using the
22
- `register_model` decorator, and its name correctly passed to this model
23
- configuration (see Examples).
24
-
25
- Attributes
26
- ----------
27
- architecture : Literal["custom"]
28
- Discriminator for the custom model, must be set to "custom".
29
- name : str
30
- Name of the custom model.
31
- parameters : CustomParametersModel
32
- All parameters, required for the initialization of the torch module have to be
33
- passed here.
34
-
35
- Raises
36
- ------
37
- ValueError
38
- If the custom model `name` is unknown.
39
- ValueError
40
- If the custom model is not a torch Module subclass.
41
- ValueError
42
- If the custom model parameters are not valid.
43
-
44
- Examples
45
- --------
46
- >>> from torch import nn, ones
47
- >>> from careamics.config import CustomModel, register_model
48
- >>> # Register a custom model
49
- >>> @register_model(name="my_linear")
50
- ... class LinearModel(nn.Module):
51
- ... def __init__(self, in_features, out_features, *args, **kwargs):
52
- ... super().__init__()
53
- ... self.in_features = in_features
54
- ... self.out_features = out_features
55
- ... self.weight = nn.Parameter(ones(in_features, out_features))
56
- ... self.bias = nn.Parameter(ones(out_features))
57
- ... def forward(self, input):
58
- ... return (input @ self.weight) + self.bias
59
- ...
60
- >>> # Create a configuration
61
- >>> config_dict = {
62
- ... "architecture": "custom",
63
- ... "name": "my_linear",
64
- ... "in_features": 10,
65
- ... "out_features": 5,
66
- ... }
67
- >>> config = CustomModel(**config_dict)
68
- """
69
-
70
- # pydantic model config
71
- model_config = ConfigDict(
72
- extra="allow",
73
- )
74
-
75
- # discriminator used for choosing the pydantic model in Model
76
- architecture: Literal["custom"]
77
- """Name of the architecture."""
78
-
79
- name: str
80
- """Name of the custom model."""
81
-
82
- @field_validator("name")
83
- @classmethod
84
- def custom_model_is_known(cls, value: str) -> str:
85
- """Check whether the custom model is known.
86
-
87
- Parameters
88
- ----------
89
- value : str
90
- Name of the custom model as registered using the `@register_model`
91
- decorator.
92
-
93
- Returns
94
- -------
95
- str
96
- The custom model name.
97
- """
98
- # delegate error to get_custom_model
99
- model = get_custom_model(value)
100
-
101
- # check if it is a torch Module subclass
102
- if not issubclass(model, Module):
103
- raise ValueError(
104
- f'Retrieved class {model} with name "{value}" is not a '
105
- f"torch.nn.Module subclass."
106
- )
107
-
108
- return value
109
-
110
- @model_validator(mode="after")
111
- def check_parameters(self: Self) -> Self:
112
- """Validate model by instantiating the model with the parameters.
113
-
114
- Returns
115
- -------
116
- Self
117
- The validated model.
118
- """
119
- # instantiate model
120
- try:
121
- get_custom_model(self.name)(**self.model_dump())
122
- except Exception as e:
123
- raise ValueError(
124
- f"while passing parameters to the model {e}. Verify that all "
125
- f"mandatory parameters are provided, and that either the {e} accepts "
126
- f"*args and **kwargs in its __init__() method, or that no additional"
127
- f"parameter is provided. Trace: "
128
- f"filename: {inspect.trace()[-1].filename}, function: "
129
- f"{inspect.trace()[-1].function}, line: {inspect.trace()[-1].lineno}"
130
- ) from None
131
-
132
- return self
133
-
134
- def __str__(self) -> str:
135
- """Pretty string representing the configuration.
136
-
137
- Returns
138
- -------
139
- str
140
- Pretty string.
141
- """
142
- return pformat(self.model_dump())
143
-
144
- def model_dump(self, **kwargs: Any) -> dict[str, Any]:
145
- """Dump the model configuration.
146
-
147
- Parameters
148
- ----------
149
- **kwargs : Any
150
- Additional keyword arguments from Pydantic BaseModel model_dump method.
151
-
152
- Returns
153
- -------
154
- dict[str, Any]
155
- Model configuration.
156
- """
157
- model_dict = super().model_dump()
158
-
159
- # remove the name key
160
- model_dict.pop("name")
161
-
162
- return model_dict
@@ -1,103 +0,0 @@
1
- """Custom model registration utilities."""
2
-
3
- from typing import Callable
4
-
5
- from torch.nn import Module
6
-
7
- CUSTOM_MODELS = {} # dictionary of custom models {"name": __class__}
8
-
9
-
10
- def register_model(name: str) -> Callable:
11
- """Decorator used to register a torch.nn.Module class with a given `name`.
12
-
13
- Parameters
14
- ----------
15
- name : str
16
- Name of the model.
17
-
18
- Returns
19
- -------
20
- Callable
21
- Function allowing to instantiate the wrapped Module class.
22
-
23
- Raises
24
- ------
25
- ValueError
26
- If a model is already registered with that name.
27
-
28
- Examples
29
- --------
30
- ```python
31
- @register_model(name="linear")
32
- class LinearModel(nn.Module):
33
- def __init__(self, in_features, out_features):
34
- super().__init__()
35
-
36
- self.weight = nn.Parameter(ones(in_features, out_features))
37
- self.bias = nn.Parameter(ones(out_features))
38
-
39
- def forward(self, input):
40
- return (input @ self.weight) + self.bias
41
- ```
42
- """
43
- if name is None or name == "":
44
- raise ValueError("Model name cannot be empty.")
45
-
46
- if name in CUSTOM_MODELS:
47
- raise ValueError(
48
- f"Model {name} already exists. Choose a different name or run "
49
- f"`clear_custom_models()` to empty the registry."
50
- )
51
-
52
- def add_custom_model(model: Module) -> Module:
53
- """Add a custom model to the registry and return it.
54
-
55
- Parameters
56
- ----------
57
- model : Module
58
- Module class to register.
59
-
60
- Returns
61
- -------
62
- Module
63
- The registered model.
64
- """
65
- # add model to the registry
66
- CUSTOM_MODELS[name] = model
67
-
68
- return model
69
-
70
- return add_custom_model
71
-
72
-
73
- def get_custom_model(name: str) -> Module:
74
- """Get the custom model corresponding to `name` from the registry.
75
-
76
- Parameters
77
- ----------
78
- name : str
79
- Name of the model to retrieve.
80
-
81
- Returns
82
- -------
83
- Module
84
- The requested model.
85
-
86
- Raises
87
- ------
88
- ValueError
89
- If the model is not registered.
90
- """
91
- if name not in CUSTOM_MODELS:
92
- raise ValueError(
93
- f"Model {name} is unknown. Have you registered it using "
94
- f'@register_model("{name}") as decorator?'
95
- )
96
-
97
- return CUSTOM_MODELS[name]
98
-
99
-
100
- def clear_custom_models() -> None:
101
- """Clear the custom models registry."""
102
- # clear dictionary
103
- CUSTOM_MODELS.clear()