careamics 0.0.4.2__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. See the package's registry page for more details.

Files changed (118)
  1. careamics/__init__.py +17 -2
  2. careamics/careamist.py +239 -28
  3. careamics/cli/conf.py +19 -31
  4. careamics/cli/main.py +112 -12
  5. careamics/cli/utils.py +29 -0
  6. careamics/config/__init__.py +48 -24
  7. careamics/config/algorithms/__init__.py +15 -0
  8. careamics/config/algorithms/care_algorithm_model.py +50 -0
  9. careamics/config/algorithms/n2n_algorithm_model.py +42 -0
  10. careamics/config/algorithms/n2v_algorithm_model.py +35 -0
  11. careamics/config/algorithms/unet_algorithm_model.py +88 -0
  12. careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +26 -23
  13. careamics/config/architectures/__init__.py +1 -11
  14. careamics/config/architectures/architecture_model.py +3 -3
  15. careamics/config/architectures/lvae_model.py +109 -21
  16. careamics/config/architectures/unet_model.py +1 -0
  17. careamics/config/care_configuration.py +100 -0
  18. careamics/config/configuration.py +354 -0
  19. careamics/config/{configuration_factory.py → configuration_factories.py} +152 -81
  20. careamics/config/configuration_io.py +85 -0
  21. careamics/config/data/__init__.py +10 -0
  22. careamics/config/{data_model.py → data/data_model.py} +58 -198
  23. careamics/config/data/n2v_data_model.py +193 -0
  24. careamics/config/likelihood_model.py +8 -8
  25. careamics/config/loss_model.py +56 -0
  26. careamics/config/n2n_configuration.py +101 -0
  27. careamics/config/n2v_configuration.py +266 -0
  28. careamics/config/nm_model.py +24 -25
  29. careamics/config/support/__init__.py +7 -7
  30. careamics/config/support/supported_algorithms.py +0 -3
  31. careamics/config/support/supported_architectures.py +0 -4
  32. careamics/config/transformations/__init__.py +10 -4
  33. careamics/config/transformations/transform_model.py +3 -3
  34. careamics/config/transformations/transform_unions.py +42 -0
  35. careamics/config/validators/validator_utils.py +3 -3
  36. careamics/dataset/__init__.py +2 -2
  37. careamics/dataset/dataset_utils/__init__.py +3 -3
  38. careamics/dataset/dataset_utils/dataset_utils.py +4 -6
  39. careamics/dataset/dataset_utils/file_utils.py +9 -9
  40. careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
  41. careamics/dataset/dataset_utils/running_stats.py +22 -23
  42. careamics/dataset/in_memory_dataset.py +11 -12
  43. careamics/dataset/iterable_dataset.py +4 -4
  44. careamics/dataset/iterable_pred_dataset.py +2 -1
  45. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  46. careamics/dataset/patching/random_patching.py +11 -10
  47. careamics/dataset/patching/sequential_patching.py +26 -26
  48. careamics/dataset/patching/validate_patch_dimension.py +3 -3
  49. careamics/dataset/tiling/__init__.py +2 -2
  50. careamics/dataset/tiling/collate_tiles.py +3 -3
  51. careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
  52. careamics/dataset/tiling/tiled_patching.py +11 -10
  53. careamics/file_io/__init__.py +5 -5
  54. careamics/file_io/read/__init__.py +1 -1
  55. careamics/file_io/read/get_func.py +2 -2
  56. careamics/file_io/write/__init__.py +2 -2
  57. careamics/lightning/__init__.py +5 -5
  58. careamics/lightning/callbacks/__init__.py +1 -1
  59. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
  60. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
  61. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
  62. careamics/lightning/callbacks/progress_bar_callback.py +2 -2
  63. careamics/lightning/lightning_module.py +69 -34
  64. careamics/lightning/train_data_module.py +41 -27
  65. careamics/losses/__init__.py +3 -3
  66. careamics/losses/loss_factory.py +1 -85
  67. careamics/losses/lvae/losses.py +223 -164
  68. careamics/lvae_training/calibration.py +184 -0
  69. careamics/lvae_training/dataset/config.py +2 -2
  70. careamics/lvae_training/dataset/multich_dataset.py +11 -19
  71. careamics/lvae_training/dataset/multifile_dataset.py +3 -2
  72. careamics/lvae_training/dataset/types.py +15 -26
  73. careamics/lvae_training/dataset/utils/index_manager.py +4 -4
  74. careamics/lvae_training/eval_utils.py +125 -213
  75. careamics/model_io/__init__.py +1 -1
  76. careamics/model_io/bioimage/__init__.py +1 -1
  77. careamics/model_io/bioimage/_readme_factory.py +26 -34
  78. careamics/model_io/bioimage/cover_factory.py +171 -0
  79. careamics/model_io/bioimage/model_description.py +56 -34
  80. careamics/model_io/bmz_io.py +42 -42
  81. careamics/model_io/model_io_utils.py +9 -9
  82. careamics/models/layers.py +22 -20
  83. careamics/models/lvae/layers.py +348 -975
  84. careamics/models/lvae/likelihoods.py +10 -8
  85. careamics/models/lvae/lvae.py +214 -275
  86. careamics/models/lvae/noise_models.py +179 -112
  87. careamics/models/lvae/stochastic.py +393 -0
  88. careamics/models/lvae/utils.py +82 -73
  89. careamics/models/model_factory.py +2 -15
  90. careamics/models/unet.py +8 -8
  91. careamics/prediction_utils/__init__.py +1 -1
  92. careamics/prediction_utils/prediction_outputs.py +15 -15
  93. careamics/prediction_utils/stitch_prediction.py +6 -6
  94. careamics/transforms/__init__.py +5 -5
  95. careamics/transforms/compose.py +13 -13
  96. careamics/transforms/n2v_manipulate.py +3 -3
  97. careamics/transforms/pixel_manipulation.py +9 -9
  98. careamics/transforms/xy_random_rotate90.py +4 -4
  99. careamics/utils/__init__.py +5 -5
  100. careamics/utils/context.py +2 -1
  101. careamics/utils/lightning_utils.py +57 -0
  102. careamics/utils/logging.py +11 -10
  103. careamics/utils/serializers.py +2 -0
  104. careamics/utils/torch_utils.py +8 -8
  105. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/METADATA +16 -13
  106. careamics-0.0.6.dist-info/RECORD +176 -0
  107. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/WHEEL +1 -1
  108. careamics/config/architectures/custom_model.py +0 -162
  109. careamics/config/architectures/register_model.py +0 -103
  110. careamics/config/configuration_model.py +0 -603
  111. careamics/config/fcn_algorithm_model.py +0 -152
  112. careamics/config/references/__init__.py +0 -45
  113. careamics/config/references/algorithm_descriptions.py +0 -132
  114. careamics/config/references/references.py +0 -39
  115. careamics/config/transformations/transform_union.py +0 -20
  116. careamics-0.0.4.2.dist-info/RECORD +0 -165
  117. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
  118. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
@@ -5,7 +5,7 @@ These functions are used to control certain aspects and behaviours of PyTorch.
5
5
  """
6
6
 
7
7
  import inspect
8
- from typing import Dict, Union
8
+ from typing import Union
9
9
 
10
10
  import torch
11
11
 
@@ -27,12 +27,12 @@ def filter_parameters(
27
27
  ----------
28
28
  func : type
29
29
  Class object.
30
- user_params : Dict
30
+ user_params : dict
31
31
  User provided parameters.
32
32
 
33
33
  Returns
34
34
  -------
35
- Dict
35
+ dict
36
36
  Parameters matching `func`'s signature.
37
37
  """
38
38
  # Get the list of all default parameters
@@ -64,13 +64,13 @@ def get_optimizer(name: str) -> torch.optim.Optimizer:
64
64
  return getattr(torch.optim, name)
65
65
 
66
66
 
67
- def get_optimizers() -> Dict[str, str]:
67
+ def get_optimizers() -> dict[str, str]:
68
68
  """
69
69
  Return the list of all optimizers available in torch.optim.
70
70
 
71
71
  Returns
72
72
  -------
73
- Dict
73
+ dict
74
74
  Optimizers available in torch.optim.
75
75
  """
76
76
  optims = {}
@@ -84,7 +84,7 @@ def get_optimizers() -> Dict[str, str]:
84
84
  def get_scheduler(
85
85
  name: str,
86
86
  ) -> Union[
87
- torch.optim.lr_scheduler.LRScheduler,
87
+ # torch.optim.lr_scheduler.LRScheduler,
88
88
  torch.optim.lr_scheduler.ReduceLROnPlateau,
89
89
  ]:
90
90
  """
@@ -106,13 +106,13 @@ def get_scheduler(
106
106
  return getattr(torch.optim.lr_scheduler, name)
107
107
 
108
108
 
109
- def get_schedulers() -> Dict[str, str]:
109
+ def get_schedulers() -> dict[str, str]:
110
110
  """
111
111
  Return the list of all schedulers available in torch.optim.lr_scheduler.
112
112
 
113
113
  Returns
114
114
  -------
115
- Dict
115
+ dict
116
116
  Schedulers available in torch.optim.lr_scheduler.
117
117
  """
118
118
  schedulers = {}
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: careamics
3
- Version: 0.0.4.2
3
+ Version: 0.0.6
4
4
  Summary: Toolbox for running N2V and friends.
5
5
  Project-URL: homepage, https://careamics.github.io/
6
6
  Project-URL: repository, https://github.com/CAREamics/careamics
@@ -16,19 +16,21 @@ Classifier: Programming Language :: Python :: 3.11
16
16
  Classifier: Programming Language :: Python :: 3.12
17
17
  Classifier: Typing :: Typed
18
18
  Requires-Python: >=3.9
19
- Requires-Dist: bioimageio-core>=0.6.9
19
+ Requires-Dist: bioimageio-core==0.7
20
20
  Requires-Dist: numpy<2.0.0
21
- Requires-Dist: psutil
22
- Requires-Dist: pydantic<2.9,>=2.5
23
- Requires-Dist: pytorch-lightning>=2.2.0
24
- Requires-Dist: pyyaml
25
- Requires-Dist: scikit-image<=0.23.2
26
- Requires-Dist: tifffile
27
- Requires-Dist: torch>=2.0.0
28
- Requires-Dist: torchvision
29
- Requires-Dist: typer==0.12.3
21
+ Requires-Dist: pillow<=11.1.0
22
+ Requires-Dist: psutil<=6.1.1
23
+ Requires-Dist: pydantic<2.11,>=2.5
24
+ Requires-Dist: pytorch-lightning<=2.5.0.post0,>=2.2
25
+ Requires-Dist: pyyaml!=6.0.0,<=6.0.2
26
+ Requires-Dist: scikit-image<=0.25.0
27
+ Requires-Dist: tifffile<=2025.1.10
28
+ Requires-Dist: torch<=2.5.1,>=2.0
29
+ Requires-Dist: torchvision<=0.20.1
30
+ Requires-Dist: typer<=0.15.1,>=0.12.3
30
31
  Requires-Dist: zarr<3.0.0
31
32
  Provides-Extra: dev
33
+ Requires-Dist: onnx; extra == 'dev'
32
34
  Requires-Dist: pre-commit; extra == 'dev'
33
35
  Requires-Dist: pytest; extra == 'dev'
34
36
  Requires-Dist: pytest-cov; extra == 'dev'
@@ -38,7 +40,7 @@ Requires-Dist: careamics-portfolio; extra == 'examples'
38
40
  Requires-Dist: jupyter; extra == 'examples'
39
41
  Requires-Dist: matplotlib; extra == 'examples'
40
42
  Provides-Extra: tensorboard
41
- Requires-Dist: protobuf==3.20.3; extra == 'tensorboard'
43
+ Requires-Dist: protobuf==5.29.1; extra == 'tensorboard'
42
44
  Requires-Dist: tensorboard; extra == 'tensorboard'
43
45
  Provides-Extra: wandb
44
46
  Requires-Dist: wandb; extra == 'wandb'
@@ -57,6 +59,7 @@ Description-Content-Type: text/markdown
57
59
  [![Python Version](https://img.shields.io/pypi/pyversions/careamics.svg?color=green)](https://python.org)
58
60
  [![CI](https://github.com/CAREamics/careamics/actions/workflows/ci.yml/badge.svg)](https://github.com/CAREamics/careamics/actions/workflows/ci.yml)
59
61
  [![codecov](https://codecov.io/gh/CAREamics/careamics/branch/main/graph/badge.svg)](https://codecov.io/gh/CAREamics/careamics)
62
+ [![Image.sc](https://img.shields.io/badge/Got%20a%20question%3F-Image.sc-blue)](https://forum.image.sc/)
60
63
 
61
64
 
62
65
  CAREamics is a PyTorch library aimed at simplifying the use of Noise2Void and its many
@@ -0,0 +1,176 @@
1
+ careamics/__init__.py,sha256=WF2JpQmC-MmuSB0L81XRo67NwaN_0qjyywcpRlbVJVE,569
2
+ careamics/careamist.py,sha256=rakSaDSGRR0Pr0o1s8ejwIrjWknOQs4DgljukqfyWu0,37635
3
+ careamics/conftest.py,sha256=Od4WcaaP0UP-XUMrFr_oo4e6c2hi_RvNbuaRTopwlmI,911
4
+ careamics/py.typed,sha256=esB4cHc6c07uVkGtqf8at7ttEnprwRxwk8obY8Qumq4,187
5
+ careamics/cli/__init__.py,sha256=LbM9bVtU1dy-khmdiIDXwvKy2v8wPBCEUuWqV_8rosA,106
6
+ careamics/cli/conf.py,sha256=oixGRZNylW-NTM_rkDtQkSRw8KUYwtmUC_hK5BEeLnA,13074
7
+ careamics/cli/main.py,sha256=S4B3c1ZN-OQK0l2_W42CaW0KmF_Pe_y4pKgn_UOuyDg,6564
8
+ careamics/cli/utils.py,sha256=q_dmG7lxg_FT62qX9fPilIWL1M8ibhLnnhUKqa4knPI,660
9
+ careamics/config/__init__.py,sha256=c5FXtcQrtROHdLRl4cHpKo6V4_E4yr6HxYgNwqH9CHg,1917
10
+ careamics/config/callback_model.py,sha256=EeYHqpMIPQwyNxLRzzX32Uncl5mZuB1bJO76RHpNymg,4555
11
+ careamics/config/care_configuration.py,sha256=hSfNJ-dooHm4ujG6Q3Hawr8zDeYy1HNiebtM7gxsh7s,2714
12
+ careamics/config/configuration.py,sha256=KmLeXHkFhQTrcru1erhfVf3tHvQdoi12ls_u254rtDw,11114
13
+ careamics/config/configuration_factories.py,sha256=-kdoLQ0dkV4xlcUx_odrtMsvL4NqGYg5-4HVl3PqY4c,32423
14
+ careamics/config/configuration_io.py,sha256=ks9R8lRCBY_m0sdy1k4ZWFPKEFSp7K9X47jCG4d0FY4,2353
15
+ careamics/config/inference_model.py,sha256=UE_-ZmCX6LFCbDBOwyGnvuAboF_JNX2m2LcF0WiwgCI,6961
16
+ careamics/config/likelihood_model.py,sha256=VorUtc0_-xIWNxwVrd1kBba-003ICdVMtxpcDCxH4Io,2259
17
+ careamics/config/loss_model.py,sha256=yYcUBS90Qyon1MxeaHiVP3dJHPJFC0GUvWKGcAb3IHk,2036
18
+ careamics/config/n2n_configuration.py,sha256=QoHnEak0LiG6rBIXnmx61Chkz-Q0jr2Qt9I0NnmgZo4,2667
19
+ careamics/config/n2v_configuration.py,sha256=5MagUFaVZ-9VGIiaKD-4RRkFpJen2imMoR3Lu07BD14,8489
20
+ careamics/config/nm_model.py,sha256=5dAhDBLa4WPfKaNEK6ATNsSUwtlH8u8gYweEA4gZP6g,4758
21
+ careamics/config/optimizer_models.py,sha256=OWpTydRBBR8wt_af1mZHNNwvL_RtnRFopAOdgjzLo30,5750
22
+ careamics/config/tile_information.py,sha256=c-_xrVPOgcnjiEzQ-9A_GhNPamObkMANbeHaRP29R-4,2059
23
+ careamics/config/training_model.py,sha256=67_ipo_-LxhT4-WqAs40Sg8PjU--my43Qn3BhjvlXxM,3212
24
+ careamics/config/algorithms/__init__.py,sha256=on5D6zBO9Lu-Tf5To8xpF6owIFqdN7RSmZdpyXDOaNw,404
25
+ careamics/config/algorithms/care_algorithm_model.py,sha256=mdMDc0Vr0V7LZmksdjNDlRx3ofwy57IMXygY_vg6hPY,1175
26
+ careamics/config/algorithms/n2n_algorithm_model.py,sha256=FO04TRecvMPikz_s2WupwODzi4g3x84g1HXRPgMAZWo,1015
27
+ careamics/config/algorithms/n2v_algorithm_model.py,sha256=TBAUKeFxdRtomYhFrVqfq31hpOiNQO2YOogYzVVFknM,919
28
+ careamics/config/algorithms/unet_algorithm_model.py,sha256=OaBFVlhsb9YhF3f2x1ImazvfnZ4_DPWvYWihRwurkeg,2587
29
+ careamics/config/algorithms/vae_algorithm_model.py,sha256=1ZrShUGHA7zbjSiCQwhUSv7l7Hr7MbpcLOw8ucM27p8,4680
30
+ careamics/config/architectures/__init__.py,sha256=lYUz56R7LDqVQWQDLkLgJL8VtOyxc_ege1r4bXGEBqA,220
31
+ careamics/config/architectures/architecture_model.py,sha256=gXn4gdLrQP3bmTQxIhzkEHYlodPaIp6LI-kwZl23W-Y,911
32
+ careamics/config/architectures/lvae_model.py,sha256=lwOiJYNUqPrZFl9SpPLYon77EiRbe2eI7pmpx45rO78,7606
33
+ careamics/config/architectures/unet_model.py,sha256=HJJWf-wuYTv8KXMwikdjeB8htg1QOt0IyQbMuBt1LUI,3556
34
+ careamics/config/data/__init__.py,sha256=ijFNSRKkKVF8fw6ym8kXq_wNhdwkWXxuNj2XKtD3KjE,218
35
+ careamics/config/data/data_model.py,sha256=lSAIjpxuLaw1GMJ0rOHOZghYBwqWwlZf5sCphwAKMUc,11111
36
+ careamics/config/data/n2v_data_model.py,sha256=-n5cncmcrd4-KUSmEk8rXAONMC7sfbYfGmshVNuySCU,5915
37
+ careamics/config/support/__init__.py,sha256=ktWvxbTkRXnQPS_N84l9E2B5kTZVdd64SIjsJIQKB-k,1041
38
+ careamics/config/support/supported_activations.py,sha256=CqOWoziIK5jZZXJO7G7cGg3TTid1POqv8FXqxjXxyME,535
39
+ careamics/config/support/supported_algorithms.py,sha256=ivLECrplTLjSUsa7AyT1n2aOo_WM3sMJJlW3OOT1fPk,846
40
+ careamics/config/support/supported_architectures.py,sha256=pOxvHOAIUkc7HCO0IIg4K22h-Ti5ErtcIkGOjN-zh1s,340
41
+ careamics/config/support/supported_data.py,sha256=T_mDiWLFMVji_EpjBABUObAJcnv-XBnqp9XUZP37Tdk,2902
42
+ careamics/config/support/supported_loggers.py,sha256=ubSOkGoYabGbm_jmyc1R3eFcvcP-sHmuyiBi_d3_wLg,197
43
+ careamics/config/support/supported_losses.py,sha256=2x5sZuxRbWJzodoL35I1mMYUUDMzk8UFiFdbyPwbJ4E,583
44
+ careamics/config/support/supported_optimizers.py,sha256=_2XmwzYENB6xpTedyWHUdWuGcDzdlfEAJjzm_qI3yRM,1392
45
+ careamics/config/support/supported_pixel_manipulations.py,sha256=rFiktUlvoFU7s1NAKEMqsXOzLw5eaw9GtCKUznvq6xc,432
46
+ careamics/config/support/supported_struct_axis.py,sha256=alZMA5Y-BpDymLPUEd1zqVY0xMkgl9Rv1d4ujED6sco,424
47
+ careamics/config/support/supported_transforms.py,sha256=ODvmoTywvJWG_5-SJJZu-X1FNtKGhkNWQc-t26IFZWI,311
48
+ careamics/config/transformations/__init__.py,sha256=jMTUX15n8ZF4Nc9gQp-qbXfhj-iEsEa65lzNbHwzyfY,631
49
+ careamics/config/transformations/n2v_manipulate_model.py,sha256=Mdxc4J3vxe_dM2CIhmTwwGOIirQvrQXLoa2vRsTzoYI,1855
50
+ careamics/config/transformations/normalize_model.py,sha256=1Rkk6IkF-7ytGU6HSzP-TpOi4RRWiQJ6fOd8zammXcg,1936
51
+ careamics/config/transformations/transform_model.py,sha256=6UVbXnxm-LLZOQQ-ZBwWwgmS_99DiBuERLfMxrta3-8,990
52
+ careamics/config/transformations/transform_unions.py,sha256=uqlI8Nm827bKfMbDQLVhKFtT9e7TJ_zIYDBdHlOuQ1I,1137
53
+ careamics/config/transformations/xy_flip_model.py,sha256=zU-uZ1b1zNZWckbho3onN-B7BHKhN7jbgbNZyRQhv2s,1025
54
+ careamics/config/transformations/xy_random_rotate90_model.py,sha256=6sYKmtCLvz0SV1qZgBSHUTH-CUjwvHnohq1HyPntbyE,894
55
+ careamics/config/validators/__init__.py,sha256=iv0nVI0W7j9DxFPwh0DjRCzM9P8oLQn4Gwi5rfuFrrI,180
56
+ careamics/config/validators/validator_utils.py,sha256=NVkEOr5AQK4JXWNtmgeQgAaJOyieJNb5PHCjlcqNeew,2611
57
+ careamics/dataset/__init__.py,sha256=31vop67zbtGesENEIig-LLw1q2lCydMFc_YWgfK2Yt4,547
58
+ careamics/dataset/in_memory_dataset.py,sha256=MV_Vf4siIP-g7VKhxN4rU7MZXpaHKvfwr8ZXqk44Qhs,9958
59
+ careamics/dataset/in_memory_pred_dataset.py,sha256=VvwW5D8TjgO_kR8eZinP-9qepSiI6ZsUN7FZ0Rvc8Bs,2161
60
+ careamics/dataset/in_memory_tiled_pred_dataset.py,sha256=DANmlnlV1ysXKdwGvmJoOYKcjlgoMhnSGSDRpeK79ZA,3552
61
+ careamics/dataset/iterable_dataset.py,sha256=pqtm-AWhDbuZTnXf0roAHWVxGPRTekAzVcJHLzaSyFU,9797
62
+ careamics/dataset/iterable_pred_dataset.py,sha256=jee4b8bZOyvSS5qfIsb6Jkk1EV_MKEU2SyZ0m7p0p9k,3767
63
+ careamics/dataset/iterable_tiled_pred_dataset.py,sha256=2j_kLMB6DfSKXPszZPYgsB08TVgcf1V5HY_kZVozrFM,4560
64
+ careamics/dataset/zarr_dataset.py,sha256=lojnK5bhiF1vyjuPtWXBrZ9sy5fT_rBvZJbbbnE-H_I,5665
65
+ careamics/dataset/dataset_utils/__init__.py,sha256=MJ3xriL6R4ZtmzbvLsASUWLb85Hk5AdeRaYnHpNELJQ,507
66
+ careamics/dataset/dataset_utils/dataset_utils.py,sha256=X83DzaOWmHdl4eOPac2IQJH3bPA43RVq0hPrFrzvIXQ,2630
67
+ careamics/dataset/dataset_utils/file_utils.py,sha256=ru6AtQ9LCmo6raN1-GnJEN4UyP1PbmSdR9MEys3CuHo,4094
68
+ careamics/dataset/dataset_utils/iterate_over_files.py,sha256=Jun35Qn9XevHOb_DixYBMHDAOykLmiwciA5Q2MzSUK8,2912
69
+ careamics/dataset/dataset_utils/running_stats.py,sha256=kWorioMH4S5uZj2cvUpjHB6cIUhMFa1XXwDQrrKIWdI,5752
70
+ careamics/dataset/patching/__init__.py,sha256=7-s12oUAZNlMOwSkxSwbD7vojQINWYFzn_4qIJ87WBg,37
71
+ careamics/dataset/patching/patching.py,sha256=deAxY34Iz-mguBlHQ-5EO4vRhPpR9I3LQ9onV1K_KqA,8858
72
+ careamics/dataset/patching/random_patching.py,sha256=gm1jxye9yvHbdijLzCtDSzRU_9j110GRLMnJaUwLAHQ,6487
73
+ careamics/dataset/patching/sequential_patching.py,sha256=4F5E1Ta0M5kFXGwI2-QXRxeOx0CyUwbFaB5awkMCN_Q,5890
74
+ careamics/dataset/patching/validate_patch_dimension.py,sha256=mC2bZWBpU44NEvXxEfR7ULUKwWPuZPjmBWpHYJxNDWc,2121
75
+ careamics/dataset/tiling/__init__.py,sha256=aW_AMB9rzm0VmooUpjcyqv6sQP69RlPQMEdP2sVjdz8,190
76
+ careamics/dataset/tiling/collate_tiles.py,sha256=XK0BsDQE7XwIwmOoCHJIpVC3kqjSN6nDhrJ4POVeHS8,965
77
+ careamics/dataset/tiling/lvae_tiled_patching.py,sha256=LYEEdjKuKaxIGFtOkhfpsE7hruBnIsD5HcW9aVH6WHI,13019
78
+ careamics/dataset/tiling/tiled_patching.py,sha256=6vxsqlccUqIl4Ys92JWIPs0Kn95VzaHoAYMSGcp2dh8,5956
79
+ careamics/file_io/__init__.py,sha256=vgMI77X820VOWywAEW5W20FXfmbqBzx4V63D3V3_HhI,334
80
+ careamics/file_io/read/__init__.py,sha256=wf8O_o80ghrlWQ-RGEuSqcc2LU55P1B-oxTacDToygo,259
81
+ careamics/file_io/read/get_func.py,sha256=O_pdymjh2mc-JZ1je3ZnPAcsHc7Je3a005AMgAa0xuw,1388
82
+ careamics/file_io/read/tiff.py,sha256=UMofW33rvByK9B1zYGhSrWAiAA3uQUV3OVK7cq9d0gQ,1359
83
+ careamics/file_io/read/zarr.py,sha256=2jzREAnJDQSv0qmsL-v00BxmiZ_sp0ijq667LZSQ_hY,1685
84
+ careamics/file_io/write/__init__.py,sha256=CUt33cRjG9hm18L9a7XqaUKWQ_3xiuQ9ztz4Ab7RYG0,283
85
+ careamics/file_io/write/get_func.py,sha256=hyGHe1RX-lfa9QFAnwRCz_gS0NRiRnXEtg4Bdeh2Esc,1627
86
+ careamics/file_io/write/tiff.py,sha256=tBGIgl-I1sMyBivgx-dOTBykXBODkgwPH8MT3_4KAE8,1050
87
+ careamics/lightning/__init__.py,sha256=ATCVAGnX08Ik4TxbIv0-cXb52UinR42JgvZh_GIMSpc,588
88
+ careamics/lightning/lightning_module.py,sha256=3nzRcjv7dlWmWjoWmGaBV7D6d698xMp3XO7fpIrZhwA,22630
89
+ careamics/lightning/predict_data_module.py,sha256=JNwujK6QwObSx6P25ghpGl2f2gGT3KVgYMTlonZzH20,12745
90
+ careamics/lightning/train_data_module.py,sha256=7fti07y2TTV_Airl-D9FRTJf7AmsFmErSKnWUGLQE58,28699
91
+ careamics/lightning/callbacks/__init__.py,sha256=eA5ltzYNzuO0uMEr1jG4wP01b0s29s5I03WGJ290qkw,312
92
+ careamics/lightning/callbacks/hyperparameters_callback.py,sha256=u45knOZHwoVHz6yYfrnERQuozT_SfZ1OrKP0QjeU4EM,1495
93
+ careamics/lightning/callbacks/progress_bar_callback.py,sha256=RilGAVUa90AlCLdooIGJF2cAcHAjGn24zKZiBmRRkwg,2438
94
+ careamics/lightning/callbacks/prediction_writer_callback/__init__.py,sha256=ZVf3vaSU_NjSjrKbI24H0kK9WAiP9oKXfhP670EaWMo,548
95
+ careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py,sha256=i4vGGiVLslafi-5iuvkAKzBgZ0BpwTTxSTo31oViFz4,1480
96
+ careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py,sha256=8HHUSKcG7G0FSCVPnpGQHLfpara5mnKAwsiiyWp2wzo,8210
97
+ careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py,sha256=lxsLjLskRpYnzdyWCdOICUJxF9YzuUi1RH0LJnOCVgo,12594
98
+ careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py,sha256=F1IpbNNgkv5eK8Xpqp7wqv2lsqEdP1wMRlBL7RBn93U,7114
99
+ careamics/losses/__init__.py,sha256=nSWbkBcFhkyUkIT2wVcULqpieyY2Oro39NXZTtfQpXo,351
100
+ careamics/losses/loss_factory.py,sha256=oPacrkwiabsmiW_r--IxX-XPRbzezZUvOuWKbUw5LiI,1518
101
+ careamics/losses/fcn/__init__.py,sha256=kf92MKFGHr6upiztZVgWwtGPf734DZyub92Rn8uEq8o,18
102
+ careamics/losses/fcn/losses.py,sha256=NdOz29hzJ7D26p13q-g0NWoYwNauIWrP2xWww6YPbB8,2360
103
+ careamics/losses/lvae/__init__.py,sha256=0FNtMLHrOMfagtWkaBdz1NTjyf2y0QLgysxJv5jq5uw,19
104
+ careamics/losses/lvae/loss_utils.py,sha256=QxzA2N1TglR4H0X0uyTWWytDagE1lA9IB_TK1lms3ao,2720
105
+ careamics/losses/lvae/losses.py,sha256=wHT1dx04BZ_OI-_S7cFQ5hFmMetm6FSnuZfwZBBtIpY,17977
106
+ careamics/lvae_training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
107
+ careamics/lvae_training/calibration.py,sha256=CvtmRC1s-2XHDjt1XG3RdHlPgCOaoCweJMKZYmDmgnU,6508
108
+ careamics/lvae_training/eval_utils.py,sha256=RU1FYK9z1Hno6b6n7XyTTOu6An0MX0_6C8Og9FFqpyM,30949
109
+ careamics/lvae_training/get_config.py,sha256=dwVfaQS7nzjQss0E1gGLUpQpjPcOWwLgIhbu3Z0I1rg,3068
110
+ careamics/lvae_training/lightning_module.py,sha256=ryr7iHqCMzCl5esi6_gEcnKFDQkMrw0EXK9Zfgv1Nek,27186
111
+ careamics/lvae_training/metrics.py,sha256=KTDAKhe3vh-YxzGibjtkIG2nnUyujbnwqX4xGwaRXwE,6718
112
+ careamics/lvae_training/train_lvae.py,sha256=lJEBlBGdISVkZBcEnPNRYgJ7VbapYzZHRaFOrZ0xYGE,11080
113
+ careamics/lvae_training/train_utils.py,sha256=e-d4QsF-li8MmAPkAmB1daHpkuU16nBTnQFZYqpTjn4,3567
114
+ careamics/lvae_training/dataset/__init__.py,sha256=dvdHHaRA9ZfOt_uOnXkYyra2_b0Wsxs8qmrze6zxJAE,377
115
+ careamics/lvae_training/dataset/config.py,sha256=hGIggj5uOZrFBK54o9vii0sG5WGhF_E32URKIIzQMec,4342
116
+ careamics/lvae_training/dataset/lc_dataset.py,sha256=xErygllUu6Q-PfPZ24sHf5_NP7YGHD2NVyzmDZgDd2U,10697
117
+ careamics/lvae_training/dataset/multich_dataset.py,sha256=J1QWXlTSLZ40D3MFKw2StarZpq82sFeaHSXk7j48RAc,41608
118
+ careamics/lvae_training/dataset/multifile_dataset.py,sha256=hJBs6iBrf_FcyUYzg8rDjvKEICHxDYyXVOj-5L0F6FE,10273
119
+ careamics/lvae_training/dataset/types.py,sha256=zfi-zMmMe7GTaX-MYrYfVbAM4D2LPHrJkmqSFl9ulxA,632
120
+ careamics/lvae_training/dataset/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
121
+ careamics/lvae_training/dataset/utils/data_utils.py,sha256=8PvRPqSbYHPCl87cycZHxXIFOT_EoBV-8XCt3ZLh36s,3125
122
+ careamics/lvae_training/dataset/utils/empty_patch_fetcher.py,sha256=OFjeqhZ6vFULJsF5tnoByhEhE8aLHujFToU_yyqMCP4,2266
123
+ careamics/lvae_training/dataset/utils/index_manager.py,sha256=Gt1I-7lBaQDBgqguOmofAFDdAsJQfz7ktvq4_I80F9c,10084
124
+ careamics/lvae_training/dataset/utils/index_switcher.py,sha256=ZoMi8LsaIkm8MFqIFaxN4oQGyzCcwOlCom8SYNus15E,6716
125
+ careamics/model_io/__init__.py,sha256=khMIkk107LL5JGze0OVfl5Lfi14R3_e4W21tW0iJ1kE,155
126
+ careamics/model_io/bmz_io.py,sha256=0fL0cz_80nsu_Jgk-ImXKTNbn_WSdrXB9flvLm2g3Is,7792
127
+ careamics/model_io/model_io_utils.py,sha256=HCSoNvaOo55kI7teZleN57riqNK2fLq77zMadBKXCyc,2777
128
+ careamics/model_io/bioimage/__init__.py,sha256=dKm-UluVwa_7EHQB0ukx-Qk8JVWtuT-OUinY8hE9EIw,298
129
+ careamics/model_io/bioimage/_readme_factory.py,sha256=sRfHbuwhfYmpwsG0keXFCSK3qCS4VI4GQhX6OhSKLjY,3440
130
+ careamics/model_io/bioimage/bioimage_utils.py,sha256=YVr75SDfafOiuYGonPbcsO-xVS0wI1WkmWQQZx6DXYQ,1246
131
+ careamics/model_io/bioimage/cover_factory.py,sha256=8URrpEfJvdHBJeSrh5H2IQHSUybsTyAOR3_A-YYAAlw,4583
132
+ careamics/model_io/bioimage/model_description.py,sha256=HwjWT77nIt6LkLvIBRxwgPOsL0ffrCMtDUjim39qgqE,10071
133
+ careamics/models/__init__.py,sha256=Xui2BLJd1I2r_E3Sj24fJALFTi2FGtfNscUWj_0c9Hk,93
134
+ careamics/models/activation.py,sha256=nu3sDgd7Lsyw8rvmUwxNN-7SM09cEMfxZ9DRDzdSKns,1049
135
+ careamics/models/layers.py,sha256=tpsxbolRWYycZGxS3hKlDRtMf6HpNdZs98uwx5K8lls,13757
136
+ careamics/models/model_factory.py,sha256=GWbouERvEfHj_BrpKHYgrPAj8dpdoh1R-X--Jt09N1Q,1370
137
+ careamics/models/unet.py,sha256=9m8GxsTXX9c0mC-eVe2ZXQn2afose1CG6Z8vIhELb7I,14308
138
+ careamics/models/lvae/__init__.py,sha256=6dT6uqgT__V08EjoTGxXguTbTkySZmByS9J2Bj6WWLM,53
139
+ careamics/models/lvae/layers.py,sha256=UPxZiZjgnBnPs_wdSzcP-_s17MEw4P1CoIZzn4OdUA0,57944
140
+ careamics/models/lvae/likelihoods.py,sha256=SHsaZZTjGg8TO6JkvjdQOZtu8CCOvlSxtqRBZhrGXKk,12098
141
+ careamics/models/lvae/lvae.py,sha256=AWll857rAnlvjTom5CX0CpbEcf-ci_Icoz5wFKVLu5E,34220
142
+ careamics/models/lvae/noise_models.py,sha256=dFWM9DDJ7qzsyiT8sDWR8THbEiZmu3XnW5UzGJl7Mck,21834
143
+ careamics/models/lvae/stochastic.py,sha256=019M6BBR6GtYjUVF6pcOTOsfEAYeg0vclz55V6Fl4yY,16588
144
+ careamics/models/lvae/utils.py,sha256=EE3paHu3vhCaqfOrGypzUsImZJO94uBhx8q6kZ-R36o,11516
145
+ careamics/prediction_utils/__init__.py,sha256=k5hsPGY8FOkwIT0fQgrUz7fVCH2NlwuOdZiISdXjEWg,270
146
+ careamics/prediction_utils/lvae_prediction.py,sha256=ZwPFCSeUGsULIMoMQWRYKHfLFaDm7UKyGaUMVfSUqfs,6210
147
+ careamics/prediction_utils/lvae_tiling_manager.py,sha256=SI-JaJvLrKWBSHdm-FjcqWdbhlcflTRiKxYF7CSGzvA,13736
148
+ careamics/prediction_utils/prediction_outputs.py,sha256=fw-bJ2szWJD7BgZlECmxy5sgeXGFJl4T8cRNzLR1aUQ,4069
149
+ careamics/prediction_utils/stitch_prediction.py,sha256=8YRW2rea-is5tYI0Q1bw3bpX7VMFmbpxSP_y6x9Yfug,3893
150
+ careamics/transforms/__init__.py,sha256=WtgpSFL_CJwpa47XzqS7bVXHPJ4qW0TamEymy_kgWQQ,483
151
+ careamics/transforms/compose.py,sha256=ZSVwKg3LT2PrwtSBKtkb6AHsVSSlZIdv9wTYl4To1s4,5682
152
+ careamics/transforms/n2v_manipulate.py,sha256=t9rtMbYV6P1IVp4yzuJfq5-giWyfGrxL8ZhzP29Pp8k,5686
153
+ careamics/transforms/normalize.py,sha256=fxs813ydCWrIzrxFzkbk1gW8OGSr0esQSrNUFSJuGL0,7715
154
+ careamics/transforms/pixel_manipulation.py,sha256=38NsxY8ARvz7GSNDKx5g67Hv5qciBzadiELAg5OcUSU,13355
155
+ careamics/transforms/struct_mask_parameters.py,sha256=jE29Li9sx3olaRnqYfJsSlKi2t0WQzJmCm9aCbIQEsA,421
156
+ careamics/transforms/transform.py,sha256=cEqc4ci8na70i-HIGYC7udRfVa8D_8OjdRVrr3txLvQ,464
157
+ careamics/transforms/tta.py,sha256=78S7Df9rLHmEVSQSI1qDcRrRJGauyG3oaIrXkckCkmw,2335
158
+ careamics/transforms/xy_flip.py,sha256=64BDo8bmAEwO1TNhbIYcUJPzzVmY5ZyNaSNmmGLkn0U,3842
159
+ careamics/transforms/xy_random_rotate90.py,sha256=Kin42yaV4Z8lOwC9nN8gxK73rgnJ2MhCoHHPQmlSgvc,3185
160
+ careamics/utils/__init__.py,sha256=mLwBQ7wTL2EwDwL3NcX53EHPNklojU45Jcc728y4EWQ,402
161
+ careamics/utils/autocorrelation.py,sha256=M_WYzrEOQngc5iSXWar4S3-EOnK6DfYHPC2vVMeu_Bs,945
162
+ careamics/utils/base_enum.py,sha256=bz1D8mDx5V5hdnJ3WAzJXWHJTbgwAky5FprUt9F5cMA,1387
163
+ careamics/utils/context.py,sha256=SoTZfzG6fO4SDOGHOTL2Xlm1n1CSgb9B57GVhrEkFls,1436
164
+ careamics/utils/lightning_utils.py,sha256=DMMmqx-AlNtddBCqm8b_W3h09qUetz32OMPhdDieFwg,1769
165
+ careamics/utils/logging.py,sha256=5U4VsQ4m4OajtirLH6qUjrM1CAc-oXeCsd6JyROjkWE,10337
166
+ careamics/utils/metrics.py,sha256=yAoCvrZ1kQx-kT9xdTBYz-oh0I52ef6uBnw8qgzpwn8,10318
167
+ careamics/utils/path_utils.py,sha256=8AugiG5DOmzgSnTCJI8vypXaPE0XhnR-9pzeiFUZ-0I,554
168
+ careamics/utils/ram.py,sha256=tksyn8dVX_iJXmrDZDGub32hFZWIaNxnMheO5G1p43I,244
169
+ careamics/utils/receptive_field.py,sha256=Y2h4c8S6glX3qcx5KHDmO17Kkuyey9voxfoXyqcAfiM,3296
170
+ careamics/utils/serializers.py,sha256=mILUhz75IMpGKnEzcYu9hlOPG8YIiIW09fk6eZM7Y8k,1427
171
+ careamics/utils/torch_utils.py,sha256=_Cf3HdlIRl5hxfpUg9aofCSlcW7GSsIJxsbSORXko0U,3010
172
+ careamics-0.0.6.dist-info/METADATA,sha256=Hz6RQlh5szAGllr7baFnc86U_sSC2fCjvU4FhjmhYMk,3898
173
+ careamics-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
174
+ careamics-0.0.6.dist-info/entry_points.txt,sha256=2fSNVXJWDJgFLATVj7MkjFNvpl53amG8tUzC3jf7G1s,53
175
+ careamics-0.0.6.dist-info/licenses/LICENSE,sha256=6zdNW-k_xHRKYWUf9tDI_ZplUciFHyj0g16DYuZ2udw,1509
176
+ careamics-0.0.6.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: hatchling 1.25.0
2
+ Generator: hatchling 1.27.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
@@ -1,162 +0,0 @@
1
- """Custom architecture Pydantic model."""
2
-
3
- from __future__ import annotations
4
-
5
- import inspect
6
- from pprint import pformat
7
- from typing import Any, Literal
8
-
9
- from pydantic import ConfigDict, field_validator, model_validator
10
- from torch.nn import Module
11
- from typing_extensions import Self
12
-
13
- from .architecture_model import ArchitectureModel
14
- from .register_model import get_custom_model
15
-
16
-
17
- class CustomModel(ArchitectureModel):
18
- """Custom model configuration.
19
-
20
- This Pydantic model allows storing parameters for a custom model. In order for the
21
- model to be valid, the specific model needs to be registered using the
22
- `register_model` decorator, and its name correctly passed to this model
23
- configuration (see Examples).
24
-
25
- Attributes
26
- ----------
27
- architecture : Literal["custom"]
28
- Discriminator for the custom model, must be set to "custom".
29
- name : str
30
- Name of the custom model.
31
- parameters : CustomParametersModel
32
- All parameters, required for the initialization of the torch module have to be
33
- passed here.
34
-
35
- Raises
36
- ------
37
- ValueError
38
- If the custom model `name` is unknown.
39
- ValueError
40
- If the custom model is not a torch Module subclass.
41
- ValueError
42
- If the custom model parameters are not valid.
43
-
44
- Examples
45
- --------
46
- >>> from torch import nn, ones
47
- >>> from careamics.config import CustomModel, register_model
48
- >>> # Register a custom model
49
- >>> @register_model(name="my_linear")
50
- ... class LinearModel(nn.Module):
51
- ... def __init__(self, in_features, out_features, *args, **kwargs):
52
- ... super().__init__()
53
- ... self.in_features = in_features
54
- ... self.out_features = out_features
55
- ... self.weight = nn.Parameter(ones(in_features, out_features))
56
- ... self.bias = nn.Parameter(ones(out_features))
57
- ... def forward(self, input):
58
- ... return (input @ self.weight) + self.bias
59
- ...
60
- >>> # Create a configuration
61
- >>> config_dict = {
62
- ... "architecture": "custom",
63
- ... "name": "my_linear",
64
- ... "in_features": 10,
65
- ... "out_features": 5,
66
- ... }
67
- >>> config = CustomModel(**config_dict)
68
- """
69
-
70
- # pydantic model config
71
- model_config = ConfigDict(
72
- extra="allow",
73
- )
74
-
75
- # discriminator used for choosing the pydantic model in Model
76
- architecture: Literal["custom"]
77
- """Name of the architecture."""
78
-
79
- name: str
80
- """Name of the custom model."""
81
-
82
- @field_validator("name")
83
- @classmethod
84
- def custom_model_is_known(cls, value: str) -> str:
85
- """Check whether the custom model is known.
86
-
87
- Parameters
88
- ----------
89
- value : str
90
- Name of the custom model as registered using the `@register_model`
91
- decorator.
92
-
93
- Returns
94
- -------
95
- str
96
- The custom model name.
97
- """
98
- # delegate error to get_custom_model
99
- model = get_custom_model(value)
100
-
101
- # check if it is a torch Module subclass
102
- if not issubclass(model, Module):
103
- raise ValueError(
104
- f'Retrieved class {model} with name "{value}" is not a '
105
- f"torch.nn.Module subclass."
106
- )
107
-
108
- return value
109
-
110
- @model_validator(mode="after")
111
- def check_parameters(self: Self) -> Self:
112
- """Validate model by instantiating the model with the parameters.
113
-
114
- Returns
115
- -------
116
- Self
117
- The validated model.
118
- """
119
- # instantiate model
120
- try:
121
- get_custom_model(self.name)(**self.model_dump())
122
- except Exception as e:
123
- raise ValueError(
124
- f"while passing parameters to the model {e}. Verify that all "
125
- f"mandatory parameters are provided, and that either the {e} accepts "
126
- f"*args and **kwargs in its __init__() method, or that no additional"
127
- f"parameter is provided. Trace: "
128
- f"filename: {inspect.trace()[-1].filename}, function: "
129
- f"{inspect.trace()[-1].function}, line: {inspect.trace()[-1].lineno}"
130
- ) from None
131
-
132
- return self
133
-
134
- def __str__(self) -> str:
135
- """Pretty string representing the configuration.
136
-
137
- Returns
138
- -------
139
- str
140
- Pretty string.
141
- """
142
- return pformat(self.model_dump())
143
-
144
- def model_dump(self, **kwargs: Any) -> dict[str, Any]:
145
- """Dump the model configuration.
146
-
147
- Parameters
148
- ----------
149
- **kwargs : Any
150
- Additional keyword arguments from Pydantic BaseModel model_dump method.
151
-
152
- Returns
153
- -------
154
- dict[str, Any]
155
- Model configuration.
156
- """
157
- model_dict = super().model_dump()
158
-
159
- # remove the name key
160
- model_dict.pop("name")
161
-
162
- return model_dict
@@ -1,103 +0,0 @@
1
- """Custom model registration utilities."""
2
-
3
- from typing import Callable
4
-
5
- from torch.nn import Module
6
-
7
- CUSTOM_MODELS = {} # dictionary of custom models {"name": __class__}
8
-
9
-
10
- def register_model(name: str) -> Callable:
11
- """Decorator used to register a torch.nn.Module class with a given `name`.
12
-
13
- Parameters
14
- ----------
15
- name : str
16
- Name of the model.
17
-
18
- Returns
19
- -------
20
- Callable
21
- Function allowing to instantiate the wrapped Module class.
22
-
23
- Raises
24
- ------
25
- ValueError
26
- If a model is already registered with that name.
27
-
28
- Examples
29
- --------
30
- ```python
31
- @register_model(name="linear")
32
- class LinearModel(nn.Module):
33
- def __init__(self, in_features, out_features):
34
- super().__init__()
35
-
36
- self.weight = nn.Parameter(ones(in_features, out_features))
37
- self.bias = nn.Parameter(ones(out_features))
38
-
39
- def forward(self, input):
40
- return (input @ self.weight) + self.bias
41
- ```
42
- """
43
- if name is None or name == "":
44
- raise ValueError("Model name cannot be empty.")
45
-
46
- if name in CUSTOM_MODELS:
47
- raise ValueError(
48
- f"Model {name} already exists. Choose a different name or run "
49
- f"`clear_custom_models()` to empty the registry."
50
- )
51
-
52
- def add_custom_model(model: Module) -> Module:
53
- """Add a custom model to the registry and return it.
54
-
55
- Parameters
56
- ----------
57
- model : Module
58
- Module class to register.
59
-
60
- Returns
61
- -------
62
- Module
63
- The registered model.
64
- """
65
- # add model to the registry
66
- CUSTOM_MODELS[name] = model
67
-
68
- return model
69
-
70
- return add_custom_model
71
-
72
-
73
- def get_custom_model(name: str) -> Module:
74
- """Get the custom model corresponding to `name` from the registry.
75
-
76
- Parameters
77
- ----------
78
- name : str
79
- Name of the model to retrieve.
80
-
81
- Returns
82
- -------
83
- Module
84
- The requested model.
85
-
86
- Raises
87
- ------
88
- ValueError
89
- If the model is not registered.
90
- """
91
- if name not in CUSTOM_MODELS:
92
- raise ValueError(
93
- f"Model {name} is unknown. Have you registered it using "
94
- f'@register_model("{name}") as decorator?'
95
- )
96
-
97
- return CUSTOM_MODELS[name]
98
-
99
-
100
- def clear_custom_models() -> None:
101
- """Clear the custom models registry."""
102
- # clear dictionary
103
- CUSTOM_MODELS.clear()