spacr 0.2.81__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +2 -1
- spacr/core.py +107 -12
- spacr/gui.py +3 -2
- spacr/gui_core.py +8 -4
- spacr/gui_utils.py +4 -1
- spacr/io.py +13 -13
- spacr/measure.py +4 -4
- spacr/mediar.py +364 -0
- spacr/plot.py +5 -2
- spacr/resources/MEDIAR/.git +1 -0
- spacr/resources/MEDIAR/.gitignore +18 -0
- spacr/resources/MEDIAR/LICENSE +21 -0
- spacr/resources/MEDIAR/README.md +189 -0
- spacr/resources/MEDIAR/SetupDict.py +39 -0
- spacr/resources/MEDIAR/config/baseline.json +60 -0
- spacr/resources/MEDIAR/config/mediar_example.json +72 -0
- spacr/resources/MEDIAR/config/pred/pred_mediar.json +17 -0
- spacr/resources/MEDIAR/config/step1_pretraining/phase1.json +55 -0
- spacr/resources/MEDIAR/config/step1_pretraining/phase2.json +58 -0
- spacr/resources/MEDIAR/config/step2_finetuning/finetuning1.json +66 -0
- spacr/resources/MEDIAR/config/step2_finetuning/finetuning2.json +66 -0
- spacr/resources/MEDIAR/config/step3_prediction/base_prediction.json +16 -0
- spacr/resources/MEDIAR/config/step3_prediction/ensemble_tta.json +23 -0
- spacr/resources/MEDIAR/core/BasePredictor.py +120 -0
- spacr/resources/MEDIAR/core/BaseTrainer.py +240 -0
- spacr/resources/MEDIAR/core/Baseline/Predictor.py +59 -0
- spacr/resources/MEDIAR/core/Baseline/Trainer.py +113 -0
- spacr/resources/MEDIAR/core/Baseline/__init__.py +2 -0
- spacr/resources/MEDIAR/core/Baseline/utils.py +80 -0
- spacr/resources/MEDIAR/core/MEDIAR/EnsemblePredictor.py +105 -0
- spacr/resources/MEDIAR/core/MEDIAR/Predictor.py +234 -0
- spacr/resources/MEDIAR/core/MEDIAR/Trainer.py +172 -0
- spacr/resources/MEDIAR/core/MEDIAR/__init__.py +3 -0
- spacr/resources/MEDIAR/core/MEDIAR/utils.py +429 -0
- spacr/resources/MEDIAR/core/__init__.py +2 -0
- spacr/resources/MEDIAR/core/utils.py +40 -0
- spacr/resources/MEDIAR/evaluate.py +71 -0
- spacr/resources/MEDIAR/generate_mapping.py +121 -0
- spacr/resources/MEDIAR/image/examples/img1.tiff +0 -0
- spacr/resources/MEDIAR/image/examples/img2.tif +0 -0
- spacr/resources/MEDIAR/image/failure_cases.png +0 -0
- spacr/resources/MEDIAR/image/mediar_framework.png +0 -0
- spacr/resources/MEDIAR/image/mediar_model.PNG +0 -0
- spacr/resources/MEDIAR/image/mediar_results.png +0 -0
- spacr/resources/MEDIAR/main.py +125 -0
- spacr/resources/MEDIAR/predict.py +70 -0
- spacr/resources/MEDIAR/requirements.txt +14 -0
- spacr/resources/MEDIAR/train_tools/__init__.py +3 -0
- spacr/resources/MEDIAR/train_tools/data_utils/__init__.py +1 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/CellAware.py +88 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/LoadImage.py +161 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/NormalizeImage.py +77 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/__init__.py +3 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/modalities.pkl +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/datasetter.py +208 -0
- spacr/resources/MEDIAR/train_tools/data_utils/transforms.py +148 -0
- spacr/resources/MEDIAR/train_tools/data_utils/utils.py +84 -0
- spacr/resources/MEDIAR/train_tools/measures.py +200 -0
- spacr/resources/MEDIAR/train_tools/models/MEDIARFormer.py +102 -0
- spacr/resources/MEDIAR/train_tools/models/__init__.py +1 -0
- spacr/resources/MEDIAR/train_tools/utils.py +70 -0
- spacr/resources/MEDIAR_weights/.DS_Store +0 -0
- spacr/resources/icons/.DS_Store +0 -0
- spacr/resources/icons/plaque.png +0 -0
- spacr/resources/images/plate1_E01_T0001F001L01A01Z01C02.tif +0 -0
- spacr/resources/images/plate1_E01_T0001F001L01A02Z01C01.tif +0 -0
- spacr/resources/images/plate1_E01_T0001F001L01A03Z01C03.tif +0 -0
- spacr/settings.py +3 -1
- spacr/utils.py +15 -13
- {spacr-0.2.81.dist-info → spacr-0.3.1.dist-info}/METADATA +9 -1
- {spacr-0.2.81.dist-info → spacr-0.3.1.dist-info}/RECORD +75 -16
- {spacr-0.2.81.dist-info → spacr-0.3.1.dist-info}/LICENSE +0 -0
- {spacr-0.2.81.dist-info → spacr-0.3.1.dist-info}/WHEEL +0 -0
- {spacr-0.2.81.dist-info → spacr-0.3.1.dist-info}/entry_points.txt +0 -0
- {spacr-0.2.81.dist-info → spacr-0.3.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,60 @@
|
|
1
|
+
{
|
2
|
+
"data_setups":{
|
3
|
+
"labeled":{
|
4
|
+
"root": "/home/gihun/data/CellSeg/",
|
5
|
+
"mapping_file": "./train_tools/data_utils/mapping_labeled.json",
|
6
|
+
"tuning_mapping_file": "/home/gihun/CellSeg/train_tools/data_utils/mapping_tuning.json",
|
7
|
+
"batch_size": 8,
|
8
|
+
"valid_portion": 0.1
|
9
|
+
},
|
10
|
+
"unlabeled":{
|
11
|
+
"enabled": false
|
12
|
+
},
|
13
|
+
"public":{
|
14
|
+
"enabled": false
|
15
|
+
}
|
16
|
+
},
|
17
|
+
"train_setups":{
|
18
|
+
"model":{
|
19
|
+
"name": "swinunetr",
|
20
|
+
"params": {
|
21
|
+
"img_size": 512,
|
22
|
+
"in_channels": 3,
|
23
|
+
"out_channels": 3,
|
24
|
+
"spatial_dims": 2
|
25
|
+
},
|
26
|
+
"pretrained":{
|
27
|
+
"enabled": false
|
28
|
+
}
|
29
|
+
},
|
30
|
+
"trainer": {
|
31
|
+
"name": "baseline",
|
32
|
+
"params": {
|
33
|
+
"num_epochs": 200,
|
34
|
+
"valid_frequency": 1,
|
35
|
+
"device": "cuda:0",
|
36
|
+
"algo_params": {}
|
37
|
+
}
|
38
|
+
},
|
39
|
+
"optimizer":{
|
40
|
+
"name": "adamw",
|
41
|
+
"params": {"lr": 5e-5}
|
42
|
+
},
|
43
|
+
"scheduler":{
|
44
|
+
"enabled": false
|
45
|
+
},
|
46
|
+
"seed": 19940817
|
47
|
+
},
|
48
|
+
"pred_setups":{
|
49
|
+
"input_path":"/home/gihun/data/CellSeg/Official/TuningSet",
|
50
|
+
"output_path": "./results/baseline",
|
51
|
+
"make_submission": true,
|
52
|
+
"exp_name": "baseline",
|
53
|
+
"algo_params": {}
|
54
|
+
},
|
55
|
+
"wandb_setups":{
|
56
|
+
"project": "CellSeg",
|
57
|
+
"group": "Baseline",
|
58
|
+
"name": "baseline"
|
59
|
+
}
|
60
|
+
}
|
@@ -0,0 +1,72 @@
|
|
1
|
+
{
|
2
|
+
"data_setups":{
|
3
|
+
"labeled":{
|
4
|
+
"root": "/home/gihun/data/CellSeg/",
|
5
|
+
"mapping_file": "./train_tools/data_utils/mapping_labeled.json",
|
6
|
+
"amplified": false,
|
7
|
+
"batch_size": 8,
|
8
|
+
"valid_portion": 0.1
|
9
|
+
},
|
10
|
+
"public":{
|
11
|
+
"enabled": true,
|
12
|
+
"params":{
|
13
|
+
"root": "/home/gihun/data/CellSeg/",
|
14
|
+
"mapping_file": "./train_tools/data_utils/mapping_public.json",
|
15
|
+
"batch_size": 1
|
16
|
+
}
|
17
|
+
},
|
18
|
+
"unlabeled":{
|
19
|
+
"enabled": false
|
20
|
+
}
|
21
|
+
},
|
22
|
+
"train_setups":{
|
23
|
+
"model":{
|
24
|
+
"name": "mediar-former",
|
25
|
+
"params": {
|
26
|
+
"encoder_name": "mit_b5",
|
27
|
+
"encoder_weights": "imagenet",
|
28
|
+
"decoder_channels": [1024, 512, 256, 128, 64],
|
29
|
+
"decoder_pab_channels": 256,
|
30
|
+
"in_channels": 3,
|
31
|
+
"classes": 3
|
32
|
+
},
|
33
|
+
"pretrained":{
|
34
|
+
"enabled": false,
|
35
|
+
"weights": "./weights/pretrained/phase2.pth",
|
36
|
+
"strict": false
|
37
|
+
}
|
38
|
+
},
|
39
|
+
"trainer": {
|
40
|
+
"name": "mediar",
|
41
|
+
"params": {
|
42
|
+
"num_epochs": 200,
|
43
|
+
"valid_frequency": 1,
|
44
|
+
"device": "cuda:0",
|
45
|
+
"amp": true,
|
46
|
+
"algo_params": {"with_public": false}
|
47
|
+
}
|
48
|
+
},
|
49
|
+
"optimizer":{
|
50
|
+
"name": "adamw",
|
51
|
+
"params": {"lr": 5e-5}
|
52
|
+
},
|
53
|
+
"scheduler":{
|
54
|
+
"enabled": true,
|
55
|
+
"name": "cosine",
|
56
|
+
"params": {"T_max": 100, "eta_min": 1e-7}
|
57
|
+
},
|
58
|
+
"seed": 19940817
|
59
|
+
},
|
60
|
+
"pred_setups":{
|
61
|
+
"input_path":"/home/gihun/data/CellSeg/Official/TuningSet",
|
62
|
+
"output_path": "./mediar_example",
|
63
|
+
"make_submission": true,
|
64
|
+
"exp_name": "mediar_example",
|
65
|
+
"algo_params": {"use_tta": false}
|
66
|
+
},
|
67
|
+
"wandb_setups":{
|
68
|
+
"project": "CellSeg",
|
69
|
+
"group": "MEDIAR",
|
70
|
+
"name": "mediar_example"
|
71
|
+
}
|
72
|
+
}
|
@@ -0,0 +1,17 @@
|
|
1
|
+
{
|
2
|
+
"pred_setups":{
|
3
|
+
"name": "medair",
|
4
|
+
"input_path":"input_path",
|
5
|
+
"output_path": "./test",
|
6
|
+
"make_submission": true,
|
7
|
+
"model_path": "model_path",
|
8
|
+
"device": "cuda:0",
|
9
|
+
"model":{
|
10
|
+
"name": "mediar-former",
|
11
|
+
"params": {},
|
12
|
+
"pretrained":{
|
13
|
+
"enabled": false
|
14
|
+
}
|
15
|
+
}
|
16
|
+
}
|
17
|
+
}
|
@@ -0,0 +1,55 @@
|
|
1
|
+
{
|
2
|
+
"data_setups":{
|
3
|
+
"labeled":{
|
4
|
+
"root": "/home/gihun/MEDIAR/",
|
5
|
+
"mapping_file": "./train_tools/data_utils/mapping_public.json",
|
6
|
+
"mapping_file_tuning": "/home/gihun/MEDIAR/train_tools/data_utils/mapping_tuning.json",
|
7
|
+
"batch_size": 9,
|
8
|
+
"valid_portion": 0
|
9
|
+
},
|
10
|
+
"public":{
|
11
|
+
"enabled": false,
|
12
|
+
"params":{}
|
13
|
+
}
|
14
|
+
},
|
15
|
+
"train_setups":{
|
16
|
+
"model":{
|
17
|
+
"name": "mediar-former",
|
18
|
+
"params": {},
|
19
|
+
"pretrained":{
|
20
|
+
"enabled": false
|
21
|
+
}
|
22
|
+
},
|
23
|
+
"trainer": {
|
24
|
+
"name": "mediar",
|
25
|
+
"params": {
|
26
|
+
"num_epochs": 80,
|
27
|
+
"valid_frequency": 10,
|
28
|
+
"device": "cuda:0",
|
29
|
+
"amp": true,
|
30
|
+
"algo_params": {"with_public": false}
|
31
|
+
}
|
32
|
+
},
|
33
|
+
"optimizer":{
|
34
|
+
"name": "adamw",
|
35
|
+
"ft_rate": 1.0,
|
36
|
+
"params": {"lr": 5e-5}
|
37
|
+
},
|
38
|
+
"scheduler":{
|
39
|
+
"enabled": true,
|
40
|
+
"name": "cosine",
|
41
|
+
"params": {"T_max": 80, "eta_min": 1e-6}
|
42
|
+
},
|
43
|
+
"seed": 19940817
|
44
|
+
},
|
45
|
+
"pred_setups":{
|
46
|
+
"input_path":"/home/gihun/MEDIAR/data/Official/Tuning/images",
|
47
|
+
"output_path": "./mediar_pretrained_phase1",
|
48
|
+
"make_submission": false
|
49
|
+
},
|
50
|
+
"wandb_setups":{
|
51
|
+
"project": "CellSeg",
|
52
|
+
"group": "Pretraining",
|
53
|
+
"name": "phase1"
|
54
|
+
}
|
55
|
+
}
|
@@ -0,0 +1,58 @@
|
|
1
|
+
{
|
2
|
+
"data_setups":{
|
3
|
+
"labeled":{
|
4
|
+
"root": "/home/gihun/MEDIAR/",
|
5
|
+
"mapping_file": "./train_tools/data_utils/mapping_labeled.json",
|
6
|
+
"mapping_file_tuning": "/home/gihun/MEDIAR/train_tools/data_utils/mapping_tuning.json",
|
7
|
+
"join_mapping_file": "./train_tools/data_utils/mapping_public.json",
|
8
|
+
"batch_size": 9,
|
9
|
+
"valid_portion": 0
|
10
|
+
},
|
11
|
+
"unlabeled":{
|
12
|
+
"enabled": false
|
13
|
+
},
|
14
|
+
"public":{
|
15
|
+
"enabled": false
|
16
|
+
}
|
17
|
+
},
|
18
|
+
"train_setups":{
|
19
|
+
"model":{
|
20
|
+
"name": "mediar-former",
|
21
|
+
"params": {},
|
22
|
+
"pretrained":{
|
23
|
+
"enabled": false
|
24
|
+
}
|
25
|
+
},
|
26
|
+
"trainer": {
|
27
|
+
"name": "mediar",
|
28
|
+
"params": {
|
29
|
+
"num_epochs": 60,
|
30
|
+
"valid_frequency": 10,
|
31
|
+
"device": "cuda:0",
|
32
|
+
"amp": true,
|
33
|
+
"algo_params": {"with_public": false}
|
34
|
+
}
|
35
|
+
},
|
36
|
+
"optimizer":{
|
37
|
+
"name": "adamw",
|
38
|
+
"ft_rate": 1.0,
|
39
|
+
"params": {"lr": 5e-5}
|
40
|
+
},
|
41
|
+
"scheduler":{
|
42
|
+
"enabled": true,
|
43
|
+
"name": "cosine",
|
44
|
+
"params": {"T_max": 60, "eta_min": 1e-6}
|
45
|
+
},
|
46
|
+
"seed": 19940817
|
47
|
+
},
|
48
|
+
"pred_setups":{
|
49
|
+
"input_path":"/home/gihun/MEDIAR/data/Official/Tuning/images",
|
50
|
+
"output_path": "./mediar_pretrain_phase2",
|
51
|
+
"make_submission": false
|
52
|
+
},
|
53
|
+
"wandb_setups":{
|
54
|
+
"project": "CellSeg",
|
55
|
+
"group": "Pretraining",
|
56
|
+
"name": "phase2"
|
57
|
+
}
|
58
|
+
}
|
@@ -0,0 +1,66 @@
|
|
1
|
+
{
|
2
|
+
"data_setups":{
|
3
|
+
"labeled":{
|
4
|
+
"root": "/home/gihun/MEDIAR/",
|
5
|
+
"mapping_file": "./train_tools/data_utils/mapping_labeled.json",
|
6
|
+
"mapping_file_tuning": "/home/gihun/MEDIAR/train_tools/data_utils/mapping_tuning.json",
|
7
|
+
"amplified": true,
|
8
|
+
"batch_size": 8,
|
9
|
+
"valid_portion": 0.0
|
10
|
+
},
|
11
|
+
"public":{
|
12
|
+
"enabled": false,
|
13
|
+
"params":{
|
14
|
+
"root": "/home/gihun/MEDIAR/",
|
15
|
+
"mapping_file": "./train_tools/data_utils/mapping_public.json",
|
16
|
+
"batch_size": 1
|
17
|
+
}
|
18
|
+
},
|
19
|
+
"unlabeled":{
|
20
|
+
"enabled": false
|
21
|
+
}
|
22
|
+
},
|
23
|
+
"train_setups":{
|
24
|
+
"model":{
|
25
|
+
"name": "mediar-former",
|
26
|
+
"params": {},
|
27
|
+
"pretrained":{
|
28
|
+
"enabled": true,
|
29
|
+
"weights": "./weights/pretrained/phase1.pth",
|
30
|
+
"strict": false
|
31
|
+
}
|
32
|
+
},
|
33
|
+
"trainer": {
|
34
|
+
"name": "mediar",
|
35
|
+
"params": {
|
36
|
+
"num_epochs": 200,
|
37
|
+
"valid_frequency": 1,
|
38
|
+
"device": "cuda:7",
|
39
|
+
"amp": true,
|
40
|
+
"algo_params": {"with_public": false}
|
41
|
+
}
|
42
|
+
},
|
43
|
+
"optimizer":{
|
44
|
+
"name": "adamw",
|
45
|
+
"params": {"lr": 2e-5}
|
46
|
+
},
|
47
|
+
"scheduler":{
|
48
|
+
"enabled": true,
|
49
|
+
"name": "cosine",
|
50
|
+
"params": {"T_max": 100, "eta_min": 1e-7}
|
51
|
+
},
|
52
|
+
"seed": 19940817
|
53
|
+
},
|
54
|
+
"pred_setups":{
|
55
|
+
"input_path":"/home/gihun/MEDIAR/data/Official/Tuning/images",
|
56
|
+
"output_path": "./results/",
|
57
|
+
"make_submission": true,
|
58
|
+
"exp_name": "mediar_from_phase1",
|
59
|
+
"algo_params": {"use_tta": false}
|
60
|
+
},
|
61
|
+
"wandb_setups":{
|
62
|
+
"project": "CellSeg",
|
63
|
+
"group": "Fine-tuning",
|
64
|
+
"name": "from_phase1"
|
65
|
+
}
|
66
|
+
}
|
@@ -0,0 +1,66 @@
|
|
1
|
+
{
|
2
|
+
"data_setups":{
|
3
|
+
"labeled":{
|
4
|
+
"root": "/home/gihun/MEDIAR/",
|
5
|
+
"mapping_file": "./train_tools/data_utils/mapping_labeled.json",
|
6
|
+
"mapping_file_tuning": "/home/gihun/MEDIAR/train_tools/data_utils/mapping_tuning.json",
|
7
|
+
"amplified": true,
|
8
|
+
"batch_size": 8,
|
9
|
+
"valid_portion": 0.0
|
10
|
+
},
|
11
|
+
"public":{
|
12
|
+
"enabled": true,
|
13
|
+
"params":{
|
14
|
+
"root": "/home/gihun/MEDIAR/",
|
15
|
+
"mapping_file": "./train_tools/data_utils/mapping_public.json",
|
16
|
+
"batch_size": 1
|
17
|
+
}
|
18
|
+
},
|
19
|
+
"unlabeled":{
|
20
|
+
"enabled": false
|
21
|
+
}
|
22
|
+
},
|
23
|
+
"train_setups":{
|
24
|
+
"model":{
|
25
|
+
"name": "mediar-former",
|
26
|
+
"params": {},
|
27
|
+
"pretrained":{
|
28
|
+
"enabled": true,
|
29
|
+
"weights": "./weights/pretrained/phase2.pth",
|
30
|
+
"strict": false
|
31
|
+
}
|
32
|
+
},
|
33
|
+
"trainer": {
|
34
|
+
"name": "mediar",
|
35
|
+
"params": {
|
36
|
+
"num_epochs": 50,
|
37
|
+
"valid_frequency": 1,
|
38
|
+
"device": "cuda:0",
|
39
|
+
"amp": true,
|
40
|
+
"algo_params": {"with_public": true}
|
41
|
+
}
|
42
|
+
},
|
43
|
+
"optimizer":{
|
44
|
+
"name": "adamw",
|
45
|
+
"params": {"lr": 2e-5}
|
46
|
+
},
|
47
|
+
"scheduler":{
|
48
|
+
"enabled": true,
|
49
|
+
"name": "cosine",
|
50
|
+
"params": {"T_max": 100, "eta_min": 1e-7}
|
51
|
+
},
|
52
|
+
"seed": 19940817
|
53
|
+
},
|
54
|
+
"pred_setups":{
|
55
|
+
"input_path":"/home/gihun/MEDIAR/data/Official/Tuning/images",
|
56
|
+
"output_path": "./results/from_phase2",
|
57
|
+
"make_submission": true,
|
58
|
+
"exp_name": "mediar_from_phase2",
|
59
|
+
"algo_params": {"use_tta": false}
|
60
|
+
},
|
61
|
+
"wandb_setups":{
|
62
|
+
"project": "CellSeg",
|
63
|
+
"group": "Fine-tuning",
|
64
|
+
"name": "from_phase2"
|
65
|
+
}
|
66
|
+
}
|
@@ -0,0 +1,16 @@
|
|
1
|
+
{
|
2
|
+
"pred_setups":{
|
3
|
+
"name": "mediar",
|
4
|
+
"input_path":"/home/gihun/MEDIAR/data/Official/Tuning/images",
|
5
|
+
"output_path": "./results/mediar_base_prediction",
|
6
|
+
"make_submission": true,
|
7
|
+
"model_path": "./weights/finetuned/from_phase1.pth",
|
8
|
+
"device": "cuda:7",
|
9
|
+
"model":{
|
10
|
+
"name": "mediar-former",
|
11
|
+
"params": {}
|
12
|
+
},
|
13
|
+
"exp_name": "mediar_p1_base",
|
14
|
+
"algo_params": {"use_tta": false}
|
15
|
+
}
|
16
|
+
}
|
@@ -0,0 +1,23 @@
|
|
1
|
+
{
|
2
|
+
"pred_setups":{
|
3
|
+
"name": "ensemble_mediar",
|
4
|
+
"input_path":"/home/gihun/MEDIAR/data/Official/Tuning/images",
|
5
|
+
"output_path": "./results/mediar_ensemble_tta",
|
6
|
+
"make_submission": true,
|
7
|
+
"model_path1": "./weights/finetuned/from_phase1.pth",
|
8
|
+
"model_path2": "./weights/finetuned/from_phase2.pth",
|
9
|
+
"device": "cuda:0",
|
10
|
+
"model":{
|
11
|
+
"name": "mediar-former",
|
12
|
+
"params": {
|
13
|
+
"encoder_name":"mit_b5",
|
14
|
+
"decoder_channels": [1024, 512, 256, 128, 64],
|
15
|
+
"decoder_pab_channels": 256,
|
16
|
+
"in_channels":3,
|
17
|
+
"classes":3
|
18
|
+
}
|
19
|
+
},
|
20
|
+
"exp_name": "mediar_ensemble_tta",
|
21
|
+
"algo_params": {"use_tta": true}
|
22
|
+
}
|
23
|
+
}
|
@@ -0,0 +1,120 @@
|
|
1
|
+
import torch
|
2
|
+
import numpy as np
|
3
|
+
import time, os
|
4
|
+
import tifffile as tif
|
5
|
+
|
6
|
+
from datetime import datetime
|
7
|
+
from zipfile import ZipFile
|
8
|
+
from pytz import timezone
|
9
|
+
|
10
|
+
from train_tools.data_utils.transforms import get_pred_transforms
|
11
|
+
|
12
|
+
|
13
|
+
class BasePredictor:
    """Run a model over every image in a directory and save label masks.

    Subclasses must implement ``_inference`` (forward pass on a batched image)
    and ``_post_process`` (turn the raw prediction into a label mask).
    Every entry of ``algo_params`` is set as an instance attribute, so
    subclasses can read algorithm-specific options (e.g. ``use_tta``) as
    ``self.<name>``.

    Args:
        model: Object exposing ``to(device)`` and ``eval()``; moved to
            ``device`` and put in eval mode before prediction.
        device: Device passed to ``.to()`` for the model and each input.
        input_path: Directory containing the input images.
        output_path: Directory where ``<stem>_label.tiff`` masks are written
            (created if it does not exist).
        make_submission: If true, zip the contents of ``output_path`` into
            ``./submissions/<exp_name>.zip`` after prediction.
        exp_name: Optional experiment-name prefix; an ``%m%d_%H%M``
            Asia/Seoul timestamp is appended to it in ``_setups``.
        algo_params: Optional mapping of extra attributes to set on the
            instance.
    """

    def __init__(
        self,
        model,
        device,
        input_path,
        output_path,
        make_submission=False,
        exp_name=None,
        algo_params=None,
    ):
        self.model = model
        self.device = device
        self.input_path = input_path
        self.output_path = output_path
        self.make_submission = make_submission
        self.exp_name = exp_name

        # Assign algorithm-specific arguments as plain attributes.
        if algo_params:
            self.__dict__.update(algo_params)

        # Prepare inference environments (transforms, output dir, file list).
        self._setups()

    @torch.no_grad()
    def conduct_prediction(self):
        """Predict, post-process and save a mask for every input image.

        Returns:
            Total wall-clock seconds spent on inference, post-processing and
            writing across all images (``0.0`` when there are no images).
            Image loading and device transfer are not included in the timing.
        """
        self.model.to(self.device)
        self.model.eval()
        total_time = 0.0

        for img_name in self.img_names:
            img_data = self._get_img_data(img_name)
            img_data = img_data.to(self.device)

            start = time.time()

            pred_mask = self._inference(img_data)
            # Drop the batch dimension added in _get_img_data before post-processing.
            pred_mask = self._post_process(pred_mask.squeeze(0).cpu().numpy())
            self.write_pred_mask(
                pred_mask, self.output_path, img_name, self.make_submission
            )

            time_cost = time.time() - start
            total_time += time_cost
            print(
                f"Prediction finished: {img_name}; img size = {img_data.shape}; costing: {time_cost:.2f}s"
            )

        print(f"\n Total Time Cost: {total_time:.2f}s")

        if self.make_submission:
            self._write_submission_zip()

        # Previously only the last image's time was returned (and an empty
        # input directory raised UnboundLocalError); return the real total.
        return total_time

    def _write_submission_zip(self):
        """Zip every file in ``output_path`` into ``./submissions/<exp_name>.zip``."""
        fname = "%s.zip" % self.exp_name

        os.makedirs("./submissions", exist_ok=True)
        submission_path = os.path.join("./submissions", fname)

        with ZipFile(submission_path, "w") as zip_obj:
            for pred_name in sorted(os.listdir(self.output_path)):
                # No arcname is given, so entries keep their output_path prefix.
                zip_obj.write(os.path.join(self.output_path, pred_name))

        print("\n>>>>> Submission file is saved at: %s\n" % submission_path)

    @staticmethod
    def _pred_file_name(image_name):
        """Return the mask file name for ``image_name``: ``<stem>_label.tiff``.

        Only the final extension is stripped, so multi-dot names such as
        ``a.1.tif`` and ``a.2.tif`` do not collide on the same output file.
        """
        return os.path.splitext(image_name)[0] + "_label.tiff"

    def write_pred_mask(self, pred_mask, output_dir, image_name, submission=False):
        """Save ``pred_mask`` as a zlib-compressed TIFF in ``output_dir``.

        When ``submission`` is true, print a warning if the mask's maximum
        label is below 5. This assumes labels are consecutive instance ids, so
        the maximum approximates the cell count (TODO confirm for subclasses).
        """
        # All images should contain at least 5 cells.
        if submission:
            max_label = np.max(pred_mask)
            if max_label < 5:
                print("[!Caution] Only %d Cells Detected!!!\n" % max_label)

        file_path = os.path.join(output_dir, self._pred_file_name(image_name))
        tif.imwrite(file_path, pred_mask, compression="zlib")

    @staticmethod
    def _list_input_images(input_path):
        """Return sorted names of non-hidden regular files in ``input_path``.

        Hidden entries (e.g. ``.DS_Store``) and subdirectories are skipped,
        because they cannot be loaded by the prediction transforms.
        """
        return sorted(
            name
            for name in os.listdir(input_path)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(input_path, name))
        )

    def _setups(self):
        """Build transforms, create the output dir, stamp exp_name, list inputs."""
        self.pred_transforms = get_pred_transforms()
        os.makedirs(self.output_path, exist_ok=True)

        now = datetime.now(timezone("Asia/Seoul"))
        dt_string = now.strftime("%m%d_%H%M")
        self.exp_name = (
            self.exp_name + dt_string if self.exp_name is not None else dt_string
        )

        self.img_names = self._list_input_images(self.input_path)

    def _get_img_data(self, img_name):
        """Load one image through ``pred_transforms`` and add a batch dimension.

        ``pred_transforms`` is called with the file path; its result is
        presumably a torch tensor (TODO confirm shape against
        get_pred_transforms).
        """
        img_path = os.path.join(self.input_path, img_name)
        img_data = self.pred_transforms(img_path)
        return img_data.unsqueeze(0)

    def _inference(self, img_data):
        """Run the model on a batched image; implemented by subclasses."""
        raise NotImplementedError

    def _post_process(self, pred_mask):
        """Convert a raw numpy prediction into a label mask; implemented by subclasses."""
        raise NotImplementedError
|