careamics 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +17 -2
- careamics/careamist.py +4 -3
- careamics/cli/conf.py +1 -2
- careamics/cli/main.py +1 -2
- careamics/cli/utils.py +3 -3
- careamics/config/__init__.py +47 -25
- careamics/config/algorithms/__init__.py +15 -0
- careamics/config/algorithms/care_algorithm_model.py +50 -0
- careamics/config/algorithms/n2n_algorithm_model.py +42 -0
- careamics/config/algorithms/n2v_algorithm_model.py +35 -0
- careamics/config/algorithms/unet_algorithm_model.py +88 -0
- careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +14 -12
- careamics/config/architectures/__init__.py +1 -11
- careamics/config/architectures/architecture_model.py +3 -3
- careamics/config/architectures/lvae_model.py +6 -1
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/care_configuration.py +100 -0
- careamics/config/configuration.py +354 -0
- careamics/config/{configuration_factory.py → configuration_factories.py} +103 -36
- careamics/config/configuration_io.py +85 -0
- careamics/config/data/__init__.py +10 -0
- careamics/config/{data_model.py → data/data_model.py} +58 -198
- careamics/config/data/n2v_data_model.py +193 -0
- careamics/config/likelihood_model.py +1 -2
- careamics/config/n2n_configuration.py +101 -0
- careamics/config/n2v_configuration.py +266 -0
- careamics/config/nm_model.py +1 -2
- careamics/config/support/__init__.py +7 -7
- careamics/config/support/supported_algorithms.py +0 -3
- careamics/config/support/supported_architectures.py +0 -4
- careamics/config/transformations/__init__.py +10 -4
- careamics/config/transformations/transform_model.py +3 -3
- careamics/config/transformations/transform_unions.py +42 -0
- careamics/config/validators/validator_utils.py +3 -3
- careamics/dataset/__init__.py +2 -2
- careamics/dataset/dataset_utils/__init__.py +3 -3
- careamics/dataset/dataset_utils/dataset_utils.py +4 -6
- careamics/dataset/dataset_utils/file_utils.py +9 -9
- careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
- careamics/dataset/in_memory_dataset.py +11 -12
- careamics/dataset/iterable_dataset.py +4 -4
- careamics/dataset/iterable_pred_dataset.py +2 -1
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
- careamics/dataset/patching/random_patching.py +11 -10
- careamics/dataset/patching/sequential_patching.py +26 -26
- careamics/dataset/patching/validate_patch_dimension.py +3 -3
- careamics/dataset/tiling/__init__.py +2 -2
- careamics/dataset/tiling/collate_tiles.py +3 -3
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
- careamics/dataset/tiling/tiled_patching.py +11 -10
- careamics/file_io/__init__.py +5 -5
- careamics/file_io/read/__init__.py +1 -1
- careamics/file_io/read/get_func.py +2 -2
- careamics/file_io/write/__init__.py +2 -2
- careamics/lightning/__init__.py +5 -5
- careamics/lightning/callbacks/__init__.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
- careamics/lightning/callbacks/progress_bar_callback.py +2 -2
- careamics/lightning/lightning_module.py +11 -7
- careamics/lightning/train_data_module.py +26 -26
- careamics/losses/__init__.py +3 -3
- careamics/model_io/__init__.py +1 -1
- careamics/model_io/bioimage/__init__.py +1 -1
- careamics/model_io/bioimage/_readme_factory.py +1 -1
- careamics/model_io/bioimage/model_description.py +17 -17
- careamics/model_io/bmz_io.py +6 -17
- careamics/model_io/model_io_utils.py +9 -9
- careamics/models/layers.py +16 -16
- careamics/models/lvae/lvae.py +0 -3
- careamics/models/model_factory.py +2 -15
- careamics/models/unet.py +8 -8
- careamics/prediction_utils/__init__.py +1 -1
- careamics/prediction_utils/prediction_outputs.py +15 -15
- careamics/prediction_utils/stitch_prediction.py +6 -6
- careamics/transforms/__init__.py +5 -5
- careamics/transforms/compose.py +13 -13
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/pixel_manipulation.py +9 -9
- careamics/transforms/xy_random_rotate90.py +4 -4
- careamics/utils/__init__.py +5 -5
- careamics/utils/context.py +2 -1
- careamics/utils/logging.py +11 -10
- careamics/utils/torch_utils.py +7 -7
- {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/METADATA +11 -11
- {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/RECORD +90 -85
- careamics/config/architectures/custom_model.py +0 -162
- careamics/config/architectures/register_model.py +0 -103
- careamics/config/configuration_model.py +0 -603
- careamics/config/fcn_algorithm_model.py +0 -152
- careamics/config/references/__init__.py +0 -45
- careamics/config/references/algorithm_descriptions.py +0 -132
- careamics/config/references/references.py +0 -39
- careamics/config/transformations/transform_union.py +0 -20
- {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/WHEEL +0 -0
- {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,38 +1,43 @@
|
|
|
1
|
-
careamics/__init__.py,sha256=
|
|
2
|
-
careamics/careamist.py,sha256=
|
|
1
|
+
careamics/__init__.py,sha256=WF2JpQmC-MmuSB0L81XRo67NwaN_0qjyywcpRlbVJVE,569
|
|
2
|
+
careamics/careamist.py,sha256=rakSaDSGRR0Pr0o1s8ejwIrjWknOQs4DgljukqfyWu0,37635
|
|
3
3
|
careamics/conftest.py,sha256=Od4WcaaP0UP-XUMrFr_oo4e6c2hi_RvNbuaRTopwlmI,911
|
|
4
4
|
careamics/py.typed,sha256=esB4cHc6c07uVkGtqf8at7ttEnprwRxwk8obY8Qumq4,187
|
|
5
5
|
careamics/cli/__init__.py,sha256=LbM9bVtU1dy-khmdiIDXwvKy2v8wPBCEUuWqV_8rosA,106
|
|
6
|
-
careamics/cli/conf.py,sha256=
|
|
7
|
-
careamics/cli/main.py,sha256=
|
|
8
|
-
careamics/cli/utils.py,sha256=
|
|
9
|
-
careamics/config/__init__.py,sha256=
|
|
6
|
+
careamics/cli/conf.py,sha256=oixGRZNylW-NTM_rkDtQkSRw8KUYwtmUC_hK5BEeLnA,13074
|
|
7
|
+
careamics/cli/main.py,sha256=S4B3c1ZN-OQK0l2_W42CaW0KmF_Pe_y4pKgn_UOuyDg,6564
|
|
8
|
+
careamics/cli/utils.py,sha256=q_dmG7lxg_FT62qX9fPilIWL1M8ibhLnnhUKqa4knPI,660
|
|
9
|
+
careamics/config/__init__.py,sha256=c5FXtcQrtROHdLRl4cHpKo6V4_E4yr6HxYgNwqH9CHg,1917
|
|
10
10
|
careamics/config/callback_model.py,sha256=EeYHqpMIPQwyNxLRzzX32Uncl5mZuB1bJO76RHpNymg,4555
|
|
11
|
-
careamics/config/
|
|
12
|
-
careamics/config/
|
|
13
|
-
careamics/config/
|
|
14
|
-
careamics/config/
|
|
11
|
+
careamics/config/care_configuration.py,sha256=hSfNJ-dooHm4ujG6Q3Hawr8zDeYy1HNiebtM7gxsh7s,2714
|
|
12
|
+
careamics/config/configuration.py,sha256=KmLeXHkFhQTrcru1erhfVf3tHvQdoi12ls_u254rtDw,11114
|
|
13
|
+
careamics/config/configuration_factories.py,sha256=-kdoLQ0dkV4xlcUx_odrtMsvL4NqGYg5-4HVl3PqY4c,32423
|
|
14
|
+
careamics/config/configuration_io.py,sha256=ks9R8lRCBY_m0sdy1k4ZWFPKEFSp7K9X47jCG4d0FY4,2353
|
|
15
15
|
careamics/config/inference_model.py,sha256=UE_-ZmCX6LFCbDBOwyGnvuAboF_JNX2m2LcF0WiwgCI,6961
|
|
16
|
-
careamics/config/likelihood_model.py,sha256=
|
|
16
|
+
careamics/config/likelihood_model.py,sha256=VorUtc0_-xIWNxwVrd1kBba-003ICdVMtxpcDCxH4Io,2259
|
|
17
17
|
careamics/config/loss_model.py,sha256=yYcUBS90Qyon1MxeaHiVP3dJHPJFC0GUvWKGcAb3IHk,2036
|
|
18
|
-
careamics/config/
|
|
18
|
+
careamics/config/n2n_configuration.py,sha256=QoHnEak0LiG6rBIXnmx61Chkz-Q0jr2Qt9I0NnmgZo4,2667
|
|
19
|
+
careamics/config/n2v_configuration.py,sha256=5MagUFaVZ-9VGIiaKD-4RRkFpJen2imMoR3Lu07BD14,8489
|
|
20
|
+
careamics/config/nm_model.py,sha256=5dAhDBLa4WPfKaNEK6ATNsSUwtlH8u8gYweEA4gZP6g,4758
|
|
19
21
|
careamics/config/optimizer_models.py,sha256=OWpTydRBBR8wt_af1mZHNNwvL_RtnRFopAOdgjzLo30,5750
|
|
20
22
|
careamics/config/tile_information.py,sha256=c-_xrVPOgcnjiEzQ-9A_GhNPamObkMANbeHaRP29R-4,2059
|
|
21
23
|
careamics/config/training_model.py,sha256=67_ipo_-LxhT4-WqAs40Sg8PjU--my43Qn3BhjvlXxM,3212
|
|
22
|
-
careamics/config/
|
|
23
|
-
careamics/config/
|
|
24
|
-
careamics/config/
|
|
25
|
-
careamics/config/
|
|
26
|
-
careamics/config/
|
|
27
|
-
careamics/config/
|
|
28
|
-
careamics/config/architectures/
|
|
29
|
-
careamics/config/
|
|
30
|
-
careamics/config/
|
|
31
|
-
careamics/config/
|
|
32
|
-
careamics/config/
|
|
24
|
+
careamics/config/algorithms/__init__.py,sha256=on5D6zBO9Lu-Tf5To8xpF6owIFqdN7RSmZdpyXDOaNw,404
|
|
25
|
+
careamics/config/algorithms/care_algorithm_model.py,sha256=mdMDc0Vr0V7LZmksdjNDlRx3ofwy57IMXygY_vg6hPY,1175
|
|
26
|
+
careamics/config/algorithms/n2n_algorithm_model.py,sha256=FO04TRecvMPikz_s2WupwODzi4g3x84g1HXRPgMAZWo,1015
|
|
27
|
+
careamics/config/algorithms/n2v_algorithm_model.py,sha256=TBAUKeFxdRtomYhFrVqfq31hpOiNQO2YOogYzVVFknM,919
|
|
28
|
+
careamics/config/algorithms/unet_algorithm_model.py,sha256=OaBFVlhsb9YhF3f2x1ImazvfnZ4_DPWvYWihRwurkeg,2587
|
|
29
|
+
careamics/config/algorithms/vae_algorithm_model.py,sha256=1ZrShUGHA7zbjSiCQwhUSv7l7Hr7MbpcLOw8ucM27p8,4680
|
|
30
|
+
careamics/config/architectures/__init__.py,sha256=lYUz56R7LDqVQWQDLkLgJL8VtOyxc_ege1r4bXGEBqA,220
|
|
31
|
+
careamics/config/architectures/architecture_model.py,sha256=gXn4gdLrQP3bmTQxIhzkEHYlodPaIp6LI-kwZl23W-Y,911
|
|
32
|
+
careamics/config/architectures/lvae_model.py,sha256=lwOiJYNUqPrZFl9SpPLYon77EiRbe2eI7pmpx45rO78,7606
|
|
33
|
+
careamics/config/architectures/unet_model.py,sha256=HJJWf-wuYTv8KXMwikdjeB8htg1QOt0IyQbMuBt1LUI,3556
|
|
34
|
+
careamics/config/data/__init__.py,sha256=ijFNSRKkKVF8fw6ym8kXq_wNhdwkWXxuNj2XKtD3KjE,218
|
|
35
|
+
careamics/config/data/data_model.py,sha256=lSAIjpxuLaw1GMJ0rOHOZghYBwqWwlZf5sCphwAKMUc,11111
|
|
36
|
+
careamics/config/data/n2v_data_model.py,sha256=-n5cncmcrd4-KUSmEk8rXAONMC7sfbYfGmshVNuySCU,5915
|
|
37
|
+
careamics/config/support/__init__.py,sha256=ktWvxbTkRXnQPS_N84l9E2B5kTZVdd64SIjsJIQKB-k,1041
|
|
33
38
|
careamics/config/support/supported_activations.py,sha256=CqOWoziIK5jZZXJO7G7cGg3TTid1POqv8FXqxjXxyME,535
|
|
34
|
-
careamics/config/support/supported_algorithms.py,sha256=
|
|
35
|
-
careamics/config/support/supported_architectures.py,sha256=
|
|
39
|
+
careamics/config/support/supported_algorithms.py,sha256=ivLECrplTLjSUsa7AyT1n2aOo_WM3sMJJlW3OOT1fPk,846
|
|
40
|
+
careamics/config/support/supported_architectures.py,sha256=pOxvHOAIUkc7HCO0IIg4K22h-Ti5ErtcIkGOjN-zh1s,340
|
|
36
41
|
careamics/config/support/supported_data.py,sha256=T_mDiWLFMVji_EpjBABUObAJcnv-XBnqp9XUZP37Tdk,2902
|
|
37
42
|
careamics/config/support/supported_loggers.py,sha256=ubSOkGoYabGbm_jmyc1R3eFcvcP-sHmuyiBi_d3_wLg,197
|
|
38
43
|
careamics/config/support/supported_losses.py,sha256=2x5sZuxRbWJzodoL35I1mMYUUDMzk8UFiFdbyPwbJ4E,583
|
|
@@ -40,58 +45,58 @@ careamics/config/support/supported_optimizers.py,sha256=_2XmwzYENB6xpTedyWHUdWuG
|
|
|
40
45
|
careamics/config/support/supported_pixel_manipulations.py,sha256=rFiktUlvoFU7s1NAKEMqsXOzLw5eaw9GtCKUznvq6xc,432
|
|
41
46
|
careamics/config/support/supported_struct_axis.py,sha256=alZMA5Y-BpDymLPUEd1zqVY0xMkgl9Rv1d4ujED6sco,424
|
|
42
47
|
careamics/config/support/supported_transforms.py,sha256=ODvmoTywvJWG_5-SJJZu-X1FNtKGhkNWQc-t26IFZWI,311
|
|
43
|
-
careamics/config/transformations/__init__.py,sha256=
|
|
48
|
+
careamics/config/transformations/__init__.py,sha256=jMTUX15n8ZF4Nc9gQp-qbXfhj-iEsEa65lzNbHwzyfY,631
|
|
44
49
|
careamics/config/transformations/n2v_manipulate_model.py,sha256=Mdxc4J3vxe_dM2CIhmTwwGOIirQvrQXLoa2vRsTzoYI,1855
|
|
45
50
|
careamics/config/transformations/normalize_model.py,sha256=1Rkk6IkF-7ytGU6HSzP-TpOi4RRWiQJ6fOd8zammXcg,1936
|
|
46
|
-
careamics/config/transformations/transform_model.py,sha256=
|
|
47
|
-
careamics/config/transformations/
|
|
51
|
+
careamics/config/transformations/transform_model.py,sha256=6UVbXnxm-LLZOQQ-ZBwWwgmS_99DiBuERLfMxrta3-8,990
|
|
52
|
+
careamics/config/transformations/transform_unions.py,sha256=uqlI8Nm827bKfMbDQLVhKFtT9e7TJ_zIYDBdHlOuQ1I,1137
|
|
48
53
|
careamics/config/transformations/xy_flip_model.py,sha256=zU-uZ1b1zNZWckbho3onN-B7BHKhN7jbgbNZyRQhv2s,1025
|
|
49
54
|
careamics/config/transformations/xy_random_rotate90_model.py,sha256=6sYKmtCLvz0SV1qZgBSHUTH-CUjwvHnohq1HyPntbyE,894
|
|
50
55
|
careamics/config/validators/__init__.py,sha256=iv0nVI0W7j9DxFPwh0DjRCzM9P8oLQn4Gwi5rfuFrrI,180
|
|
51
|
-
careamics/config/validators/validator_utils.py,sha256=
|
|
52
|
-
careamics/dataset/__init__.py,sha256=
|
|
53
|
-
careamics/dataset/in_memory_dataset.py,sha256=
|
|
56
|
+
careamics/config/validators/validator_utils.py,sha256=NVkEOr5AQK4JXWNtmgeQgAaJOyieJNb5PHCjlcqNeew,2611
|
|
57
|
+
careamics/dataset/__init__.py,sha256=31vop67zbtGesENEIig-LLw1q2lCydMFc_YWgfK2Yt4,547
|
|
58
|
+
careamics/dataset/in_memory_dataset.py,sha256=MV_Vf4siIP-g7VKhxN4rU7MZXpaHKvfwr8ZXqk44Qhs,9958
|
|
54
59
|
careamics/dataset/in_memory_pred_dataset.py,sha256=VvwW5D8TjgO_kR8eZinP-9qepSiI6ZsUN7FZ0Rvc8Bs,2161
|
|
55
60
|
careamics/dataset/in_memory_tiled_pred_dataset.py,sha256=DANmlnlV1ysXKdwGvmJoOYKcjlgoMhnSGSDRpeK79ZA,3552
|
|
56
|
-
careamics/dataset/iterable_dataset.py,sha256=
|
|
57
|
-
careamics/dataset/iterable_pred_dataset.py,sha256=
|
|
58
|
-
careamics/dataset/iterable_tiled_pred_dataset.py,sha256=
|
|
61
|
+
careamics/dataset/iterable_dataset.py,sha256=pqtm-AWhDbuZTnXf0roAHWVxGPRTekAzVcJHLzaSyFU,9797
|
|
62
|
+
careamics/dataset/iterable_pred_dataset.py,sha256=jee4b8bZOyvSS5qfIsb6Jkk1EV_MKEU2SyZ0m7p0p9k,3767
|
|
63
|
+
careamics/dataset/iterable_tiled_pred_dataset.py,sha256=2j_kLMB6DfSKXPszZPYgsB08TVgcf1V5HY_kZVozrFM,4560
|
|
59
64
|
careamics/dataset/zarr_dataset.py,sha256=lojnK5bhiF1vyjuPtWXBrZ9sy5fT_rBvZJbbbnE-H_I,5665
|
|
60
|
-
careamics/dataset/dataset_utils/__init__.py,sha256=
|
|
61
|
-
careamics/dataset/dataset_utils/dataset_utils.py,sha256=
|
|
62
|
-
careamics/dataset/dataset_utils/file_utils.py,sha256=
|
|
63
|
-
careamics/dataset/dataset_utils/iterate_over_files.py,sha256=
|
|
65
|
+
careamics/dataset/dataset_utils/__init__.py,sha256=MJ3xriL6R4ZtmzbvLsASUWLb85Hk5AdeRaYnHpNELJQ,507
|
|
66
|
+
careamics/dataset/dataset_utils/dataset_utils.py,sha256=X83DzaOWmHdl4eOPac2IQJH3bPA43RVq0hPrFrzvIXQ,2630
|
|
67
|
+
careamics/dataset/dataset_utils/file_utils.py,sha256=ru6AtQ9LCmo6raN1-GnJEN4UyP1PbmSdR9MEys3CuHo,4094
|
|
68
|
+
careamics/dataset/dataset_utils/iterate_over_files.py,sha256=Jun35Qn9XevHOb_DixYBMHDAOykLmiwciA5Q2MzSUK8,2912
|
|
64
69
|
careamics/dataset/dataset_utils/running_stats.py,sha256=kWorioMH4S5uZj2cvUpjHB6cIUhMFa1XXwDQrrKIWdI,5752
|
|
65
70
|
careamics/dataset/patching/__init__.py,sha256=7-s12oUAZNlMOwSkxSwbD7vojQINWYFzn_4qIJ87WBg,37
|
|
66
71
|
careamics/dataset/patching/patching.py,sha256=deAxY34Iz-mguBlHQ-5EO4vRhPpR9I3LQ9onV1K_KqA,8858
|
|
67
|
-
careamics/dataset/patching/random_patching.py,sha256=
|
|
68
|
-
careamics/dataset/patching/sequential_patching.py,sha256=
|
|
69
|
-
careamics/dataset/patching/validate_patch_dimension.py,sha256=
|
|
70
|
-
careamics/dataset/tiling/__init__.py,sha256=
|
|
71
|
-
careamics/dataset/tiling/collate_tiles.py,sha256=
|
|
72
|
-
careamics/dataset/tiling/lvae_tiled_patching.py,sha256=
|
|
73
|
-
careamics/dataset/tiling/tiled_patching.py,sha256=
|
|
74
|
-
careamics/file_io/__init__.py,sha256=
|
|
75
|
-
careamics/file_io/read/__init__.py,sha256=
|
|
76
|
-
careamics/file_io/read/get_func.py,sha256=
|
|
72
|
+
careamics/dataset/patching/random_patching.py,sha256=gm1jxye9yvHbdijLzCtDSzRU_9j110GRLMnJaUwLAHQ,6487
|
|
73
|
+
careamics/dataset/patching/sequential_patching.py,sha256=4F5E1Ta0M5kFXGwI2-QXRxeOx0CyUwbFaB5awkMCN_Q,5890
|
|
74
|
+
careamics/dataset/patching/validate_patch_dimension.py,sha256=mC2bZWBpU44NEvXxEfR7ULUKwWPuZPjmBWpHYJxNDWc,2121
|
|
75
|
+
careamics/dataset/tiling/__init__.py,sha256=aW_AMB9rzm0VmooUpjcyqv6sQP69RlPQMEdP2sVjdz8,190
|
|
76
|
+
careamics/dataset/tiling/collate_tiles.py,sha256=XK0BsDQE7XwIwmOoCHJIpVC3kqjSN6nDhrJ4POVeHS8,965
|
|
77
|
+
careamics/dataset/tiling/lvae_tiled_patching.py,sha256=LYEEdjKuKaxIGFtOkhfpsE7hruBnIsD5HcW9aVH6WHI,13019
|
|
78
|
+
careamics/dataset/tiling/tiled_patching.py,sha256=6vxsqlccUqIl4Ys92JWIPs0Kn95VzaHoAYMSGcp2dh8,5956
|
|
79
|
+
careamics/file_io/__init__.py,sha256=vgMI77X820VOWywAEW5W20FXfmbqBzx4V63D3V3_HhI,334
|
|
80
|
+
careamics/file_io/read/__init__.py,sha256=wf8O_o80ghrlWQ-RGEuSqcc2LU55P1B-oxTacDToygo,259
|
|
81
|
+
careamics/file_io/read/get_func.py,sha256=O_pdymjh2mc-JZ1je3ZnPAcsHc7Je3a005AMgAa0xuw,1388
|
|
77
82
|
careamics/file_io/read/tiff.py,sha256=UMofW33rvByK9B1zYGhSrWAiAA3uQUV3OVK7cq9d0gQ,1359
|
|
78
83
|
careamics/file_io/read/zarr.py,sha256=2jzREAnJDQSv0qmsL-v00BxmiZ_sp0ijq667LZSQ_hY,1685
|
|
79
|
-
careamics/file_io/write/__init__.py,sha256=
|
|
84
|
+
careamics/file_io/write/__init__.py,sha256=CUt33cRjG9hm18L9a7XqaUKWQ_3xiuQ9ztz4Ab7RYG0,283
|
|
80
85
|
careamics/file_io/write/get_func.py,sha256=hyGHe1RX-lfa9QFAnwRCz_gS0NRiRnXEtg4Bdeh2Esc,1627
|
|
81
86
|
careamics/file_io/write/tiff.py,sha256=tBGIgl-I1sMyBivgx-dOTBykXBODkgwPH8MT3_4KAE8,1050
|
|
82
|
-
careamics/lightning/__init__.py,sha256=
|
|
83
|
-
careamics/lightning/lightning_module.py,sha256=
|
|
87
|
+
careamics/lightning/__init__.py,sha256=ATCVAGnX08Ik4TxbIv0-cXb52UinR42JgvZh_GIMSpc,588
|
|
88
|
+
careamics/lightning/lightning_module.py,sha256=3nzRcjv7dlWmWjoWmGaBV7D6d698xMp3XO7fpIrZhwA,22630
|
|
84
89
|
careamics/lightning/predict_data_module.py,sha256=JNwujK6QwObSx6P25ghpGl2f2gGT3KVgYMTlonZzH20,12745
|
|
85
|
-
careamics/lightning/train_data_module.py,sha256=
|
|
86
|
-
careamics/lightning/callbacks/__init__.py,sha256=
|
|
90
|
+
careamics/lightning/train_data_module.py,sha256=7fti07y2TTV_Airl-D9FRTJf7AmsFmErSKnWUGLQE58,28699
|
|
91
|
+
careamics/lightning/callbacks/__init__.py,sha256=eA5ltzYNzuO0uMEr1jG4wP01b0s29s5I03WGJ290qkw,312
|
|
87
92
|
careamics/lightning/callbacks/hyperparameters_callback.py,sha256=u45knOZHwoVHz6yYfrnERQuozT_SfZ1OrKP0QjeU4EM,1495
|
|
88
|
-
careamics/lightning/callbacks/progress_bar_callback.py,sha256=
|
|
89
|
-
careamics/lightning/callbacks/prediction_writer_callback/__init__.py,sha256=
|
|
93
|
+
careamics/lightning/callbacks/progress_bar_callback.py,sha256=RilGAVUa90AlCLdooIGJF2cAcHAjGn24zKZiBmRRkwg,2438
|
|
94
|
+
careamics/lightning/callbacks/prediction_writer_callback/__init__.py,sha256=ZVf3vaSU_NjSjrKbI24H0kK9WAiP9oKXfhP670EaWMo,548
|
|
90
95
|
careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py,sha256=i4vGGiVLslafi-5iuvkAKzBgZ0BpwTTxSTo31oViFz4,1480
|
|
91
|
-
careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py,sha256=
|
|
92
|
-
careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py,sha256=
|
|
96
|
+
careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py,sha256=8HHUSKcG7G0FSCVPnpGQHLfpara5mnKAwsiiyWp2wzo,8210
|
|
97
|
+
careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py,sha256=lxsLjLskRpYnzdyWCdOICUJxF9YzuUi1RH0LJnOCVgo,12594
|
|
93
98
|
careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py,sha256=F1IpbNNgkv5eK8Xpqp7wqv2lsqEdP1wMRlBL7RBn93U,7114
|
|
94
|
-
careamics/losses/__init__.py,sha256=
|
|
99
|
+
careamics/losses/__init__.py,sha256=nSWbkBcFhkyUkIT2wVcULqpieyY2Oro39NXZTtfQpXo,351
|
|
95
100
|
careamics/losses/loss_factory.py,sha256=oPacrkwiabsmiW_r--IxX-XPRbzezZUvOuWKbUw5LiI,1518
|
|
96
101
|
careamics/losses/fcn/__init__.py,sha256=kf92MKFGHr6upiztZVgWwtGPf734DZyub92Rn8uEq8o,18
|
|
97
102
|
careamics/losses/fcn/losses.py,sha256=NdOz29hzJ7D26p13q-g0NWoYwNauIWrP2xWww6YPbB8,2360
|
|
@@ -117,55 +122,55 @@ careamics/lvae_training/dataset/utils/data_utils.py,sha256=8PvRPqSbYHPCl87cycZHx
|
|
|
117
122
|
careamics/lvae_training/dataset/utils/empty_patch_fetcher.py,sha256=OFjeqhZ6vFULJsF5tnoByhEhE8aLHujFToU_yyqMCP4,2266
|
|
118
123
|
careamics/lvae_training/dataset/utils/index_manager.py,sha256=Gt1I-7lBaQDBgqguOmofAFDdAsJQfz7ktvq4_I80F9c,10084
|
|
119
124
|
careamics/lvae_training/dataset/utils/index_switcher.py,sha256=ZoMi8LsaIkm8MFqIFaxN4oQGyzCcwOlCom8SYNus15E,6716
|
|
120
|
-
careamics/model_io/__init__.py,sha256=
|
|
121
|
-
careamics/model_io/bmz_io.py,sha256=
|
|
122
|
-
careamics/model_io/model_io_utils.py,sha256=
|
|
123
|
-
careamics/model_io/bioimage/__init__.py,sha256=
|
|
124
|
-
careamics/model_io/bioimage/_readme_factory.py,sha256=
|
|
125
|
+
careamics/model_io/__init__.py,sha256=khMIkk107LL5JGze0OVfl5Lfi14R3_e4W21tW0iJ1kE,155
|
|
126
|
+
careamics/model_io/bmz_io.py,sha256=0fL0cz_80nsu_Jgk-ImXKTNbn_WSdrXB9flvLm2g3Is,7792
|
|
127
|
+
careamics/model_io/model_io_utils.py,sha256=HCSoNvaOo55kI7teZleN57riqNK2fLq77zMadBKXCyc,2777
|
|
128
|
+
careamics/model_io/bioimage/__init__.py,sha256=dKm-UluVwa_7EHQB0ukx-Qk8JVWtuT-OUinY8hE9EIw,298
|
|
129
|
+
careamics/model_io/bioimage/_readme_factory.py,sha256=sRfHbuwhfYmpwsG0keXFCSK3qCS4VI4GQhX6OhSKLjY,3440
|
|
125
130
|
careamics/model_io/bioimage/bioimage_utils.py,sha256=YVr75SDfafOiuYGonPbcsO-xVS0wI1WkmWQQZx6DXYQ,1246
|
|
126
131
|
careamics/model_io/bioimage/cover_factory.py,sha256=8URrpEfJvdHBJeSrh5H2IQHSUybsTyAOR3_A-YYAAlw,4583
|
|
127
|
-
careamics/model_io/bioimage/model_description.py,sha256=
|
|
132
|
+
careamics/model_io/bioimage/model_description.py,sha256=HwjWT77nIt6LkLvIBRxwgPOsL0ffrCMtDUjim39qgqE,10071
|
|
128
133
|
careamics/models/__init__.py,sha256=Xui2BLJd1I2r_E3Sj24fJALFTi2FGtfNscUWj_0c9Hk,93
|
|
129
134
|
careamics/models/activation.py,sha256=nu3sDgd7Lsyw8rvmUwxNN-7SM09cEMfxZ9DRDzdSKns,1049
|
|
130
|
-
careamics/models/layers.py,sha256=
|
|
131
|
-
careamics/models/model_factory.py,sha256=
|
|
132
|
-
careamics/models/unet.py,sha256=
|
|
135
|
+
careamics/models/layers.py,sha256=tpsxbolRWYycZGxS3hKlDRtMf6HpNdZs98uwx5K8lls,13757
|
|
136
|
+
careamics/models/model_factory.py,sha256=GWbouERvEfHj_BrpKHYgrPAj8dpdoh1R-X--Jt09N1Q,1370
|
|
137
|
+
careamics/models/unet.py,sha256=9m8GxsTXX9c0mC-eVe2ZXQn2afose1CG6Z8vIhELb7I,14308
|
|
133
138
|
careamics/models/lvae/__init__.py,sha256=6dT6uqgT__V08EjoTGxXguTbTkySZmByS9J2Bj6WWLM,53
|
|
134
139
|
careamics/models/lvae/layers.py,sha256=UPxZiZjgnBnPs_wdSzcP-_s17MEw4P1CoIZzn4OdUA0,57944
|
|
135
140
|
careamics/models/lvae/likelihoods.py,sha256=SHsaZZTjGg8TO6JkvjdQOZtu8CCOvlSxtqRBZhrGXKk,12098
|
|
136
|
-
careamics/models/lvae/lvae.py,sha256=
|
|
141
|
+
careamics/models/lvae/lvae.py,sha256=AWll857rAnlvjTom5CX0CpbEcf-ci_Icoz5wFKVLu5E,34220
|
|
137
142
|
careamics/models/lvae/noise_models.py,sha256=dFWM9DDJ7qzsyiT8sDWR8THbEiZmu3XnW5UzGJl7Mck,21834
|
|
138
143
|
careamics/models/lvae/stochastic.py,sha256=019M6BBR6GtYjUVF6pcOTOsfEAYeg0vclz55V6Fl4yY,16588
|
|
139
144
|
careamics/models/lvae/utils.py,sha256=EE3paHu3vhCaqfOrGypzUsImZJO94uBhx8q6kZ-R36o,11516
|
|
140
|
-
careamics/prediction_utils/__init__.py,sha256=
|
|
145
|
+
careamics/prediction_utils/__init__.py,sha256=k5hsPGY8FOkwIT0fQgrUz7fVCH2NlwuOdZiISdXjEWg,270
|
|
141
146
|
careamics/prediction_utils/lvae_prediction.py,sha256=ZwPFCSeUGsULIMoMQWRYKHfLFaDm7UKyGaUMVfSUqfs,6210
|
|
142
147
|
careamics/prediction_utils/lvae_tiling_manager.py,sha256=SI-JaJvLrKWBSHdm-FjcqWdbhlcflTRiKxYF7CSGzvA,13736
|
|
143
|
-
careamics/prediction_utils/prediction_outputs.py,sha256=
|
|
144
|
-
careamics/prediction_utils/stitch_prediction.py,sha256=
|
|
145
|
-
careamics/transforms/__init__.py,sha256=
|
|
146
|
-
careamics/transforms/compose.py,sha256=
|
|
147
|
-
careamics/transforms/n2v_manipulate.py,sha256=
|
|
148
|
+
careamics/prediction_utils/prediction_outputs.py,sha256=fw-bJ2szWJD7BgZlECmxy5sgeXGFJl4T8cRNzLR1aUQ,4069
|
|
149
|
+
careamics/prediction_utils/stitch_prediction.py,sha256=8YRW2rea-is5tYI0Q1bw3bpX7VMFmbpxSP_y6x9Yfug,3893
|
|
150
|
+
careamics/transforms/__init__.py,sha256=WtgpSFL_CJwpa47XzqS7bVXHPJ4qW0TamEymy_kgWQQ,483
|
|
151
|
+
careamics/transforms/compose.py,sha256=ZSVwKg3LT2PrwtSBKtkb6AHsVSSlZIdv9wTYl4To1s4,5682
|
|
152
|
+
careamics/transforms/n2v_manipulate.py,sha256=t9rtMbYV6P1IVp4yzuJfq5-giWyfGrxL8ZhzP29Pp8k,5686
|
|
148
153
|
careamics/transforms/normalize.py,sha256=fxs813ydCWrIzrxFzkbk1gW8OGSr0esQSrNUFSJuGL0,7715
|
|
149
|
-
careamics/transforms/pixel_manipulation.py,sha256=
|
|
154
|
+
careamics/transforms/pixel_manipulation.py,sha256=38NsxY8ARvz7GSNDKx5g67Hv5qciBzadiELAg5OcUSU,13355
|
|
150
155
|
careamics/transforms/struct_mask_parameters.py,sha256=jE29Li9sx3olaRnqYfJsSlKi2t0WQzJmCm9aCbIQEsA,421
|
|
151
156
|
careamics/transforms/transform.py,sha256=cEqc4ci8na70i-HIGYC7udRfVa8D_8OjdRVrr3txLvQ,464
|
|
152
157
|
careamics/transforms/tta.py,sha256=78S7Df9rLHmEVSQSI1qDcRrRJGauyG3oaIrXkckCkmw,2335
|
|
153
158
|
careamics/transforms/xy_flip.py,sha256=64BDo8bmAEwO1TNhbIYcUJPzzVmY5ZyNaSNmmGLkn0U,3842
|
|
154
|
-
careamics/transforms/xy_random_rotate90.py,sha256=
|
|
155
|
-
careamics/utils/__init__.py,sha256=
|
|
159
|
+
careamics/transforms/xy_random_rotate90.py,sha256=Kin42yaV4Z8lOwC9nN8gxK73rgnJ2MhCoHHPQmlSgvc,3185
|
|
160
|
+
careamics/utils/__init__.py,sha256=mLwBQ7wTL2EwDwL3NcX53EHPNklojU45Jcc728y4EWQ,402
|
|
156
161
|
careamics/utils/autocorrelation.py,sha256=M_WYzrEOQngc5iSXWar4S3-EOnK6DfYHPC2vVMeu_Bs,945
|
|
157
162
|
careamics/utils/base_enum.py,sha256=bz1D8mDx5V5hdnJ3WAzJXWHJTbgwAky5FprUt9F5cMA,1387
|
|
158
|
-
careamics/utils/context.py,sha256=
|
|
163
|
+
careamics/utils/context.py,sha256=SoTZfzG6fO4SDOGHOTL2Xlm1n1CSgb9B57GVhrEkFls,1436
|
|
159
164
|
careamics/utils/lightning_utils.py,sha256=DMMmqx-AlNtddBCqm8b_W3h09qUetz32OMPhdDieFwg,1769
|
|
160
|
-
careamics/utils/logging.py,sha256=
|
|
165
|
+
careamics/utils/logging.py,sha256=5U4VsQ4m4OajtirLH6qUjrM1CAc-oXeCsd6JyROjkWE,10337
|
|
161
166
|
careamics/utils/metrics.py,sha256=yAoCvrZ1kQx-kT9xdTBYz-oh0I52ef6uBnw8qgzpwn8,10318
|
|
162
167
|
careamics/utils/path_utils.py,sha256=8AugiG5DOmzgSnTCJI8vypXaPE0XhnR-9pzeiFUZ-0I,554
|
|
163
168
|
careamics/utils/ram.py,sha256=tksyn8dVX_iJXmrDZDGub32hFZWIaNxnMheO5G1p43I,244
|
|
164
169
|
careamics/utils/receptive_field.py,sha256=Y2h4c8S6glX3qcx5KHDmO17Kkuyey9voxfoXyqcAfiM,3296
|
|
165
170
|
careamics/utils/serializers.py,sha256=mILUhz75IMpGKnEzcYu9hlOPG8YIiIW09fk6eZM7Y8k,1427
|
|
166
|
-
careamics/utils/torch_utils.py,sha256=
|
|
167
|
-
careamics-0.0.
|
|
168
|
-
careamics-0.0.
|
|
169
|
-
careamics-0.0.
|
|
170
|
-
careamics-0.0.
|
|
171
|
-
careamics-0.0.
|
|
171
|
+
careamics/utils/torch_utils.py,sha256=_Cf3HdlIRl5hxfpUg9aofCSlcW7GSsIJxsbSORXko0U,3010
|
|
172
|
+
careamics-0.0.6.dist-info/METADATA,sha256=Hz6RQlh5szAGllr7baFnc86U_sSC2fCjvU4FhjmhYMk,3898
|
|
173
|
+
careamics-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
174
|
+
careamics-0.0.6.dist-info/entry_points.txt,sha256=2fSNVXJWDJgFLATVj7MkjFNvpl53amG8tUzC3jf7G1s,53
|
|
175
|
+
careamics-0.0.6.dist-info/licenses/LICENSE,sha256=6zdNW-k_xHRKYWUf9tDI_ZplUciFHyj0g16DYuZ2udw,1509
|
|
176
|
+
careamics-0.0.6.dist-info/RECORD,,
|
|
@@ -1,162 +0,0 @@
|
|
|
1
|
-
"""Custom architecture Pydantic model."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
import inspect
|
|
6
|
-
from pprint import pformat
|
|
7
|
-
from typing import Any, Literal
|
|
8
|
-
|
|
9
|
-
from pydantic import ConfigDict, field_validator, model_validator
|
|
10
|
-
from torch.nn import Module
|
|
11
|
-
from typing_extensions import Self
|
|
12
|
-
|
|
13
|
-
from .architecture_model import ArchitectureModel
|
|
14
|
-
from .register_model import get_custom_model
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
class CustomModel(ArchitectureModel):
|
|
18
|
-
"""Custom model configuration.
|
|
19
|
-
|
|
20
|
-
This Pydantic model allows storing parameters for a custom model. In order for the
|
|
21
|
-
model to be valid, the specific model needs to be registered using the
|
|
22
|
-
`register_model` decorator, and its name correctly passed to this model
|
|
23
|
-
configuration (see Examples).
|
|
24
|
-
|
|
25
|
-
Attributes
|
|
26
|
-
----------
|
|
27
|
-
architecture : Literal["custom"]
|
|
28
|
-
Discriminator for the custom model, must be set to "custom".
|
|
29
|
-
name : str
|
|
30
|
-
Name of the custom model.
|
|
31
|
-
parameters : CustomParametersModel
|
|
32
|
-
All parameters, required for the initialization of the torch module have to be
|
|
33
|
-
passed here.
|
|
34
|
-
|
|
35
|
-
Raises
|
|
36
|
-
------
|
|
37
|
-
ValueError
|
|
38
|
-
If the custom model `name` is unknown.
|
|
39
|
-
ValueError
|
|
40
|
-
If the custom model is not a torch Module subclass.
|
|
41
|
-
ValueError
|
|
42
|
-
If the custom model parameters are not valid.
|
|
43
|
-
|
|
44
|
-
Examples
|
|
45
|
-
--------
|
|
46
|
-
>>> from torch import nn, ones
|
|
47
|
-
>>> from careamics.config import CustomModel, register_model
|
|
48
|
-
>>> # Register a custom model
|
|
49
|
-
>>> @register_model(name="my_linear")
|
|
50
|
-
... class LinearModel(nn.Module):
|
|
51
|
-
... def __init__(self, in_features, out_features, *args, **kwargs):
|
|
52
|
-
... super().__init__()
|
|
53
|
-
... self.in_features = in_features
|
|
54
|
-
... self.out_features = out_features
|
|
55
|
-
... self.weight = nn.Parameter(ones(in_features, out_features))
|
|
56
|
-
... self.bias = nn.Parameter(ones(out_features))
|
|
57
|
-
... def forward(self, input):
|
|
58
|
-
... return (input @ self.weight) + self.bias
|
|
59
|
-
...
|
|
60
|
-
>>> # Create a configuration
|
|
61
|
-
>>> config_dict = {
|
|
62
|
-
... "architecture": "custom",
|
|
63
|
-
... "name": "my_linear",
|
|
64
|
-
... "in_features": 10,
|
|
65
|
-
... "out_features": 5,
|
|
66
|
-
... }
|
|
67
|
-
>>> config = CustomModel(**config_dict)
|
|
68
|
-
"""
|
|
69
|
-
|
|
70
|
-
# pydantic model config
|
|
71
|
-
model_config = ConfigDict(
|
|
72
|
-
extra="allow",
|
|
73
|
-
)
|
|
74
|
-
|
|
75
|
-
# discriminator used for choosing the pydantic model in Model
|
|
76
|
-
architecture: Literal["custom"]
|
|
77
|
-
"""Name of the architecture."""
|
|
78
|
-
|
|
79
|
-
name: str
|
|
80
|
-
"""Name of the custom model."""
|
|
81
|
-
|
|
82
|
-
@field_validator("name")
|
|
83
|
-
@classmethod
|
|
84
|
-
def custom_model_is_known(cls, value: str) -> str:
|
|
85
|
-
"""Check whether the custom model is known.
|
|
86
|
-
|
|
87
|
-
Parameters
|
|
88
|
-
----------
|
|
89
|
-
value : str
|
|
90
|
-
Name of the custom model as registered using the `@register_model`
|
|
91
|
-
decorator.
|
|
92
|
-
|
|
93
|
-
Returns
|
|
94
|
-
-------
|
|
95
|
-
str
|
|
96
|
-
The custom model name.
|
|
97
|
-
"""
|
|
98
|
-
# delegate error to get_custom_model
|
|
99
|
-
model = get_custom_model(value)
|
|
100
|
-
|
|
101
|
-
# check if it is a torch Module subclass
|
|
102
|
-
if not issubclass(model, Module):
|
|
103
|
-
raise ValueError(
|
|
104
|
-
f'Retrieved class {model} with name "{value}" is not a '
|
|
105
|
-
f"torch.nn.Module subclass."
|
|
106
|
-
)
|
|
107
|
-
|
|
108
|
-
return value
|
|
109
|
-
|
|
110
|
-
@model_validator(mode="after")
|
|
111
|
-
def check_parameters(self: Self) -> Self:
|
|
112
|
-
"""Validate model by instantiating the model with the parameters.
|
|
113
|
-
|
|
114
|
-
Returns
|
|
115
|
-
-------
|
|
116
|
-
Self
|
|
117
|
-
The validated model.
|
|
118
|
-
"""
|
|
119
|
-
# instantiate model
|
|
120
|
-
try:
|
|
121
|
-
get_custom_model(self.name)(**self.model_dump())
|
|
122
|
-
except Exception as e:
|
|
123
|
-
raise ValueError(
|
|
124
|
-
f"while passing parameters to the model {e}. Verify that all "
|
|
125
|
-
f"mandatory parameters are provided, and that either the {e} accepts "
|
|
126
|
-
f"*args and **kwargs in its __init__() method, or that no additional"
|
|
127
|
-
f"parameter is provided. Trace: "
|
|
128
|
-
f"filename: {inspect.trace()[-1].filename}, function: "
|
|
129
|
-
f"{inspect.trace()[-1].function}, line: {inspect.trace()[-1].lineno}"
|
|
130
|
-
) from None
|
|
131
|
-
|
|
132
|
-
return self
|
|
133
|
-
|
|
134
|
-
def __str__(self) -> str:
|
|
135
|
-
"""Pretty string representing the configuration.
|
|
136
|
-
|
|
137
|
-
Returns
|
|
138
|
-
-------
|
|
139
|
-
str
|
|
140
|
-
Pretty string.
|
|
141
|
-
"""
|
|
142
|
-
return pformat(self.model_dump())
|
|
143
|
-
|
|
144
|
-
def model_dump(self, **kwargs: Any) -> dict[str, Any]:
|
|
145
|
-
"""Dump the model configuration.
|
|
146
|
-
|
|
147
|
-
Parameters
|
|
148
|
-
----------
|
|
149
|
-
**kwargs : Any
|
|
150
|
-
Additional keyword arguments from Pydantic BaseModel model_dump method.
|
|
151
|
-
|
|
152
|
-
Returns
|
|
153
|
-
-------
|
|
154
|
-
dict[str, Any]
|
|
155
|
-
Model configuration.
|
|
156
|
-
"""
|
|
157
|
-
model_dict = super().model_dump()
|
|
158
|
-
|
|
159
|
-
# remove the name key
|
|
160
|
-
model_dict.pop("name")
|
|
161
|
-
|
|
162
|
-
return model_dict
|
|
@@ -1,103 +0,0 @@
|
|
|
1
|
-
"""Custom model registration utilities."""
|
|
2
|
-
|
|
3
|
-
from typing import Callable
|
|
4
|
-
|
|
5
|
-
from torch.nn import Module
|
|
6
|
-
|
|
7
|
-
CUSTOM_MODELS = {} # dictionary of custom models {"name": __class__}
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
def register_model(name: str) -> Callable:
|
|
11
|
-
"""Decorator used to register a torch.nn.Module class with a given `name`.
|
|
12
|
-
|
|
13
|
-
Parameters
|
|
14
|
-
----------
|
|
15
|
-
name : str
|
|
16
|
-
Name of the model.
|
|
17
|
-
|
|
18
|
-
Returns
|
|
19
|
-
-------
|
|
20
|
-
Callable
|
|
21
|
-
Function allowing to instantiate the wrapped Module class.
|
|
22
|
-
|
|
23
|
-
Raises
|
|
24
|
-
------
|
|
25
|
-
ValueError
|
|
26
|
-
If a model is already registered with that name.
|
|
27
|
-
|
|
28
|
-
Examples
|
|
29
|
-
--------
|
|
30
|
-
```python
|
|
31
|
-
@register_model(name="linear")
|
|
32
|
-
class LinearModel(nn.Module):
|
|
33
|
-
def __init__(self, in_features, out_features):
|
|
34
|
-
super().__init__()
|
|
35
|
-
|
|
36
|
-
self.weight = nn.Parameter(ones(in_features, out_features))
|
|
37
|
-
self.bias = nn.Parameter(ones(out_features))
|
|
38
|
-
|
|
39
|
-
def forward(self, input):
|
|
40
|
-
return (input @ self.weight) + self.bias
|
|
41
|
-
```
|
|
42
|
-
"""
|
|
43
|
-
if name is None or name == "":
|
|
44
|
-
raise ValueError("Model name cannot be empty.")
|
|
45
|
-
|
|
46
|
-
if name in CUSTOM_MODELS:
|
|
47
|
-
raise ValueError(
|
|
48
|
-
f"Model {name} already exists. Choose a different name or run "
|
|
49
|
-
f"`clear_custom_models()` to empty the registry."
|
|
50
|
-
)
|
|
51
|
-
|
|
52
|
-
def add_custom_model(model: Module) -> Module:
|
|
53
|
-
"""Add a custom model to the registry and return it.
|
|
54
|
-
|
|
55
|
-
Parameters
|
|
56
|
-
----------
|
|
57
|
-
model : Module
|
|
58
|
-
Module class to register.
|
|
59
|
-
|
|
60
|
-
Returns
|
|
61
|
-
-------
|
|
62
|
-
Module
|
|
63
|
-
The registered model.
|
|
64
|
-
"""
|
|
65
|
-
# add model to the registry
|
|
66
|
-
CUSTOM_MODELS[name] = model
|
|
67
|
-
|
|
68
|
-
return model
|
|
69
|
-
|
|
70
|
-
return add_custom_model
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
def get_custom_model(name: str) -> Module:
|
|
74
|
-
"""Get the custom model corresponding to `name` from the registry.
|
|
75
|
-
|
|
76
|
-
Parameters
|
|
77
|
-
----------
|
|
78
|
-
name : str
|
|
79
|
-
Name of the model to retrieve.
|
|
80
|
-
|
|
81
|
-
Returns
|
|
82
|
-
-------
|
|
83
|
-
Module
|
|
84
|
-
The requested model.
|
|
85
|
-
|
|
86
|
-
Raises
|
|
87
|
-
------
|
|
88
|
-
ValueError
|
|
89
|
-
If the model is not registered.
|
|
90
|
-
"""
|
|
91
|
-
if name not in CUSTOM_MODELS:
|
|
92
|
-
raise ValueError(
|
|
93
|
-
f"Model {name} is unknown. Have you registered it using "
|
|
94
|
-
f'@register_model("{name}") as decorator?'
|
|
95
|
-
)
|
|
96
|
-
|
|
97
|
-
return CUSTOM_MODELS[name]
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
def clear_custom_models() -> None:
|
|
101
|
-
"""Clear the custom models registry."""
|
|
102
|
-
# clear dictionary
|
|
103
|
-
CUSTOM_MODELS.clear()
|