spacr 0.2.68__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +2 -1
- spacr/core.py +107 -12
- spacr/gui.py +3 -2
- spacr/gui_core.py +160 -109
- spacr/gui_elements.py +190 -18
- spacr/gui_utils.py +4 -1
- spacr/io.py +1 -1
- spacr/measure.py +4 -4
- spacr/mediar.py +366 -0
- spacr/plot.py +4 -1
- spacr/resources/MEDIAR/.git +1 -0
- spacr/resources/MEDIAR/.gitignore +18 -0
- spacr/resources/MEDIAR/LICENSE +21 -0
- spacr/resources/MEDIAR/README.md +189 -0
- spacr/resources/MEDIAR/SetupDict.py +39 -0
- spacr/resources/MEDIAR/config/baseline.json +60 -0
- spacr/resources/MEDIAR/config/mediar_example.json +72 -0
- spacr/resources/MEDIAR/config/pred/pred_mediar.json +17 -0
- spacr/resources/MEDIAR/config/step1_pretraining/phase1.json +55 -0
- spacr/resources/MEDIAR/config/step1_pretraining/phase2.json +58 -0
- spacr/resources/MEDIAR/config/step2_finetuning/finetuning1.json +66 -0
- spacr/resources/MEDIAR/config/step2_finetuning/finetuning2.json +66 -0
- spacr/resources/MEDIAR/config/step3_prediction/base_prediction.json +16 -0
- spacr/resources/MEDIAR/config/step3_prediction/ensemble_tta.json +23 -0
- spacr/resources/MEDIAR/core/BasePredictor.py +120 -0
- spacr/resources/MEDIAR/core/BaseTrainer.py +240 -0
- spacr/resources/MEDIAR/core/Baseline/Predictor.py +59 -0
- spacr/resources/MEDIAR/core/Baseline/Trainer.py +113 -0
- spacr/resources/MEDIAR/core/Baseline/__init__.py +2 -0
- spacr/resources/MEDIAR/core/Baseline/utils.py +80 -0
- spacr/resources/MEDIAR/core/MEDIAR/EnsemblePredictor.py +105 -0
- spacr/resources/MEDIAR/core/MEDIAR/Predictor.py +234 -0
- spacr/resources/MEDIAR/core/MEDIAR/Trainer.py +172 -0
- spacr/resources/MEDIAR/core/MEDIAR/__init__.py +3 -0
- spacr/resources/MEDIAR/core/MEDIAR/utils.py +429 -0
- spacr/resources/MEDIAR/core/__init__.py +2 -0
- spacr/resources/MEDIAR/core/utils.py +40 -0
- spacr/resources/MEDIAR/evaluate.py +71 -0
- spacr/resources/MEDIAR/generate_mapping.py +121 -0
- spacr/resources/MEDIAR/image/examples/img1.tiff +0 -0
- spacr/resources/MEDIAR/image/examples/img2.tif +0 -0
- spacr/resources/MEDIAR/image/failure_cases.png +0 -0
- spacr/resources/MEDIAR/image/mediar_framework.png +0 -0
- spacr/resources/MEDIAR/image/mediar_model.PNG +0 -0
- spacr/resources/MEDIAR/image/mediar_results.png +0 -0
- spacr/resources/MEDIAR/main.py +125 -0
- spacr/resources/MEDIAR/predict.py +70 -0
- spacr/resources/MEDIAR/requirements.txt +14 -0
- spacr/resources/MEDIAR/train_tools/__init__.py +3 -0
- spacr/resources/MEDIAR/train_tools/data_utils/__init__.py +1 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/CellAware.py +88 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/LoadImage.py +161 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/NormalizeImage.py +77 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/__init__.py +3 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/modalities.pkl +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/datasetter.py +208 -0
- spacr/resources/MEDIAR/train_tools/data_utils/transforms.py +148 -0
- spacr/resources/MEDIAR/train_tools/data_utils/utils.py +84 -0
- spacr/resources/MEDIAR/train_tools/measures.py +200 -0
- spacr/resources/MEDIAR/train_tools/models/MEDIARFormer.py +102 -0
- spacr/resources/MEDIAR/train_tools/models/__init__.py +1 -0
- spacr/resources/MEDIAR/train_tools/utils.py +70 -0
- spacr/resources/MEDIAR_weights/.DS_Store +0 -0
- spacr/resources/icons/.DS_Store +0 -0
- spacr/resources/icons/plaque.png +0 -0
- spacr/resources/images/plate1_E01_T0001F001L01A01Z01C02.tif +0 -0
- spacr/resources/images/plate1_E01_T0001F001L01A02Z01C01.tif +0 -0
- spacr/resources/images/plate1_E01_T0001F001L01A03Z01C03.tif +0 -0
- spacr/sequencing.py +234 -422
- spacr/settings.py +16 -10
- spacr/utils.py +14 -11
- {spacr-0.2.68.dist-info → spacr-0.3.0.dist-info}/METADATA +10 -2
- {spacr-0.2.68.dist-info → spacr-0.3.0.dist-info}/RECORD +77 -18
- {spacr-0.2.68.dist-info → spacr-0.3.0.dist-info}/LICENSE +0 -0
- {spacr-0.2.68.dist-info → spacr-0.3.0.dist-info}/WHEEL +0 -0
- {spacr-0.2.68.dist-info → spacr-0.3.0.dist-info}/entry_points.txt +0 -0
- {spacr-0.2.68.dist-info → spacr-0.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,102 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
|
4
|
+
from segmentation_models_pytorch import MAnet
|
5
|
+
from segmentation_models_pytorch.base.modules import Activation
|
6
|
+
|
7
|
+
__all__ = ["MEDIARFormer"]
|
8
|
+
|
9
|
+
|
10
|
+
class MEDIARFormer(MAnet):
    """MAnet-based network with separate cell-probability and gradient-flow heads.

    The stock MAnet segmentation head is discarded, every ``nn.ReLU`` found in
    the encoder and decoder is replaced by ``nn.Mish``, and two
    ``DeepSegmentationHead`` modules are attached to the decoder output.
    """

    def __init__(
        self,
        encoder_name="mit_b5",
        encoder_weights="imagenet",
        decoder_channels=(1024, 512, 256, 128, 64),
        decoder_pab_channels=256,
        in_channels=3,
        classes=3,
    ):
        """Build the MAnet backbone and attach the two custom heads.

        Every argument is forwarded unchanged to ``MAnet.__init__``. The last
        entry of ``decoder_channels`` is also used as the input width of both
        heads.
        """
        backbone_kwargs = {
            "encoder_name": encoder_name,
            "encoder_weights": encoder_weights,
            "decoder_channels": decoder_channels,
            "decoder_pab_channels": decoder_pab_channels,
            "in_channels": in_channels,
            "classes": classes,
        }
        super().__init__(**backbone_kwargs)

        # The head built by MAnet is never called by forward(); drop it.
        self.segmentation_head = None

        # Swap ReLU for Mish throughout the encoder, then the decoder.
        for submodule in (self.encoder, self.decoder):
            _convert_activations(submodule, nn.ReLU, nn.Mish(inplace=True))

        # Both heads consume the final decoder stage.
        head_width = decoder_channels[-1]
        self.cellprob_head = DeepSegmentationHead(
            in_channels=head_width, out_channels=1
        )
        self.gradflow_head = DeepSegmentationHead(
            in_channels=head_width, out_channels=2
        )

    def forward(self, x):
        """Run encoder, decoder and both heads on ``x``.

        Returns the two head outputs concatenated along dim 1: the
        gradient-flow output (2 channels) first, then the cell-probability
        output (1 channel).
        """
        self.check_input_shape(x)

        encoded = self.encoder(x)
        decoded = self.decoder(*encoded)

        cellprob = self.cellprob_head(decoded)
        gradflow = self.gradflow_head(decoded)

        return torch.cat((gradflow, cellprob), dim=1)
|
64
|
+
|
65
|
+
|
66
|
+
class DeepSegmentationHead(nn.Sequential):
    """Sequential head: Conv2d -> Mish -> BatchNorm2d -> Conv2d -> upsample -> activation.

    The first convolution halves the channel count (``in_channels // 2``); the
    second maps to ``out_channels``. The upsampling and activation slots fall
    back to ``nn.Identity`` when not requested, so the head always has six
    children at fixed indices.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1
    ):
        """Assemble the six layers and hand them to ``nn.Sequential``.

        Args:
            in_channels: Channel count of the incoming feature map.
            out_channels: Channel count produced by the second convolution.
            kernel_size: Kernel size of both convolutions; padding is
                ``kernel_size // 2``.
            activation: Passed to ``Activation`` when truthy; otherwise the
                final layer is ``nn.Identity``.
            upsampling: Bilinear scale factor, applied only when greater
                than 1.
        """
        hidden = in_channels // 2
        pad = kernel_size // 2

        # Modules are created in the same order they appear in the sequence.
        first_conv = nn.Conv2d(in_channels, hidden, kernel_size=kernel_size, padding=pad)
        mish = nn.Mish(inplace=True)
        norm = nn.BatchNorm2d(hidden)
        second_conv = nn.Conv2d(hidden, out_channels, kernel_size=kernel_size, padding=pad)

        if upsampling > 1:
            resize = nn.UpsamplingBilinear2d(scale_factor=upsampling)
        else:
            resize = nn.Identity()

        if activation:
            final = Activation(activation)
        else:
            final = nn.Identity()

        super().__init__(first_conv, mish, norm, second_conv, resize, final)
|
94
|
+
|
95
|
+
|
96
|
+
def _convert_activations(module, from_activation, to_activation):
    """Recursively replace activation modules inside ``module`` in place.

    Every direct or nested child that is an instance of ``from_activation``
    is replaced by an independent copy of ``to_activation``; all other
    children are searched recursively.

    Args:
        module: The ``nn.Module`` whose children are rewritten in place.
        from_activation: Module class (or tuple of classes) to look for.
        to_activation: Module instance used as the template for each
            replacement.

    Returns:
        None. ``module`` is modified in place.
    """
    import copy  # function-scope import keeps this block self-contained

    # Iterate over a snapshot so that re-assigning children cannot interfere
    # with the traversal.
    for name, child in list(module.named_children()):
        if isinstance(child, from_activation):
            # Install a fresh copy at each position; otherwise one module
            # object would be registered at every replaced location, and
            # hooks or later edits on one would silently affect them all.
            setattr(module, name, copy.deepcopy(to_activation))
        else:
            _convert_activations(child, from_activation, to_activation)
|
@@ -0,0 +1 @@
|
|
1
|
+
from .MEDIARFormer import *
|
@@ -0,0 +1,70 @@
|
|
1
|
+
import torch
|
2
|
+
import numpy as np
|
3
|
+
import os, json, random
|
4
|
+
from pprint import pprint
|
5
|
+
|
6
|
+
__all__ = ["ConfLoader", "directory_setter", "random_seeder", "pprint_config"]
|
7
|
+
|
8
|
+
|
9
|
+
class _MissingOptionError(AttributeError, KeyError):
    """Raised when attribute-style access hits a key that is not in the config.

    It derives from ``AttributeError`` so that ``hasattr``, ``getattr`` with a
    default and ``copy.deepcopy`` behave normally, and from ``KeyError`` so
    that callers which already catch ``KeyError`` keep working.
    """


class ConfLoader:
    """
    Load a JSON config file with attribute-style access to its objects.

    Every JSON object in the file is decoded as a ``DictWithAttributeAccess``.
    The decoded result is available as ``ConfLoader(conf_name).opt``.
    """

    class DictWithAttributeAccess(dict):
        """
        A ``dict`` whose items can also be read and written as attributes.

        For example, ``opt.key`` can be used instead of ``opt['key']``.
        """

        def __getattr__(self, key):
            # Only called when normal attribute lookup fails. A missing key
            # must surface as an AttributeError, otherwise hasattr(),
            # getattr(obj, name, default) and copy.deepcopy() break.
            try:
                return self[key]
            except KeyError:
                raise _MissingOptionError(key) from None

        def __setattr__(self, key, value):
            # Attribute assignment stores a dict item, not an instance attribute.
            self[key] = value

    def __init__(self, conf_name):
        """Read the JSON file at ``conf_name`` and store the result in ``self.opt``."""
        self.conf_name = conf_name
        self.opt = self.__get_opt()

    def __load_conf(self):
        """Parse the JSON file, decoding every object as DictWithAttributeAccess."""
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(self.conf_name, "r", encoding="utf-8") as conf:
            return json.load(conf, object_hook=self.DictWithAttributeAccess)

    def __get_opt(self):
        """Return the parsed config wrapped in a DictWithAttributeAccess."""
        return self.DictWithAttributeAccess(self.__load_conf())
|
43
|
+
|
44
|
+
|
45
|
+
def directory_setter(path="./results", make_dir=False):
    """
    Ensure that ``path`` is an existing directory, optionally creating it.

    Args:
        path: Directory path to check.
        make_dir: When true and ``path`` does not exist, create it (including
            missing parents) and print a notice.

    Raises:
        NotADirectoryError: If ``path`` is not a directory after the optional
            creation step.
    """
    if make_dir and not os.path.exists(path):
        # exist_ok=True: another process may create the directory between the
        # existence check above and this call; that must not be an error.
        os.makedirs(path, exist_ok=True)
        print("directory %s is created" % path)

    if not os.path.isdir(path):
        raise NotADirectoryError(
            "%s is not valid. set make_dir=True to make dir." % path
        )
|
57
|
+
|
58
|
+
|
59
|
+
def random_seeder(seed):
    """Seed the PyTorch, NumPy and Python RNGs and set the cuDNN flags.

    ``torch.backends.cudnn.deterministic`` is set to ``True`` and
    ``torch.backends.cudnn.benchmark`` to ``False``.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
66
|
+
|
67
|
+
def pprint_config(opt):
    """Pretty-print ``opt`` to stdout between a titled header and a footer rule."""
    rule = "=" * 50
    print(f"\n{rule} Configuration {rule}")
    pprint(opt, compact=True)
    print("=" * 115, end="\n\n")
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|