spacr 0.2.68__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77) hide show
  1. spacr/__init__.py +2 -1
  2. spacr/core.py +107 -12
  3. spacr/gui.py +3 -2
  4. spacr/gui_core.py +160 -109
  5. spacr/gui_elements.py +190 -18
  6. spacr/gui_utils.py +4 -1
  7. spacr/io.py +1 -1
  8. spacr/measure.py +4 -4
  9. spacr/mediar.py +366 -0
  10. spacr/plot.py +4 -1
  11. spacr/resources/MEDIAR/.git +1 -0
  12. spacr/resources/MEDIAR/.gitignore +18 -0
  13. spacr/resources/MEDIAR/LICENSE +21 -0
  14. spacr/resources/MEDIAR/README.md +189 -0
  15. spacr/resources/MEDIAR/SetupDict.py +39 -0
  16. spacr/resources/MEDIAR/config/baseline.json +60 -0
  17. spacr/resources/MEDIAR/config/mediar_example.json +72 -0
  18. spacr/resources/MEDIAR/config/pred/pred_mediar.json +17 -0
  19. spacr/resources/MEDIAR/config/step1_pretraining/phase1.json +55 -0
  20. spacr/resources/MEDIAR/config/step1_pretraining/phase2.json +58 -0
  21. spacr/resources/MEDIAR/config/step2_finetuning/finetuning1.json +66 -0
  22. spacr/resources/MEDIAR/config/step2_finetuning/finetuning2.json +66 -0
  23. spacr/resources/MEDIAR/config/step3_prediction/base_prediction.json +16 -0
  24. spacr/resources/MEDIAR/config/step3_prediction/ensemble_tta.json +23 -0
  25. spacr/resources/MEDIAR/core/BasePredictor.py +120 -0
  26. spacr/resources/MEDIAR/core/BaseTrainer.py +240 -0
  27. spacr/resources/MEDIAR/core/Baseline/Predictor.py +59 -0
  28. spacr/resources/MEDIAR/core/Baseline/Trainer.py +113 -0
  29. spacr/resources/MEDIAR/core/Baseline/__init__.py +2 -0
  30. spacr/resources/MEDIAR/core/Baseline/utils.py +80 -0
  31. spacr/resources/MEDIAR/core/MEDIAR/EnsemblePredictor.py +105 -0
  32. spacr/resources/MEDIAR/core/MEDIAR/Predictor.py +234 -0
  33. spacr/resources/MEDIAR/core/MEDIAR/Trainer.py +172 -0
  34. spacr/resources/MEDIAR/core/MEDIAR/__init__.py +3 -0
  35. spacr/resources/MEDIAR/core/MEDIAR/utils.py +429 -0
  36. spacr/resources/MEDIAR/core/__init__.py +2 -0
  37. spacr/resources/MEDIAR/core/utils.py +40 -0
  38. spacr/resources/MEDIAR/evaluate.py +71 -0
  39. spacr/resources/MEDIAR/generate_mapping.py +121 -0
  40. spacr/resources/MEDIAR/image/examples/img1.tiff +0 -0
  41. spacr/resources/MEDIAR/image/examples/img2.tif +0 -0
  42. spacr/resources/MEDIAR/image/failure_cases.png +0 -0
  43. spacr/resources/MEDIAR/image/mediar_framework.png +0 -0
  44. spacr/resources/MEDIAR/image/mediar_model.PNG +0 -0
  45. spacr/resources/MEDIAR/image/mediar_results.png +0 -0
  46. spacr/resources/MEDIAR/main.py +125 -0
  47. spacr/resources/MEDIAR/predict.py +70 -0
  48. spacr/resources/MEDIAR/requirements.txt +14 -0
  49. spacr/resources/MEDIAR/train_tools/__init__.py +3 -0
  50. spacr/resources/MEDIAR/train_tools/data_utils/__init__.py +1 -0
  51. spacr/resources/MEDIAR/train_tools/data_utils/custom/CellAware.py +88 -0
  52. spacr/resources/MEDIAR/train_tools/data_utils/custom/LoadImage.py +161 -0
  53. spacr/resources/MEDIAR/train_tools/data_utils/custom/NormalizeImage.py +77 -0
  54. spacr/resources/MEDIAR/train_tools/data_utils/custom/__init__.py +3 -0
  55. spacr/resources/MEDIAR/train_tools/data_utils/custom/modalities.pkl +0 -0
  56. spacr/resources/MEDIAR/train_tools/data_utils/datasetter.py +208 -0
  57. spacr/resources/MEDIAR/train_tools/data_utils/transforms.py +148 -0
  58. spacr/resources/MEDIAR/train_tools/data_utils/utils.py +84 -0
  59. spacr/resources/MEDIAR/train_tools/measures.py +200 -0
  60. spacr/resources/MEDIAR/train_tools/models/MEDIARFormer.py +102 -0
  61. spacr/resources/MEDIAR/train_tools/models/__init__.py +1 -0
  62. spacr/resources/MEDIAR/train_tools/utils.py +70 -0
  63. spacr/resources/MEDIAR_weights/.DS_Store +0 -0
  64. spacr/resources/icons/.DS_Store +0 -0
  65. spacr/resources/icons/plaque.png +0 -0
  66. spacr/resources/images/plate1_E01_T0001F001L01A01Z01C02.tif +0 -0
  67. spacr/resources/images/plate1_E01_T0001F001L01A02Z01C01.tif +0 -0
  68. spacr/resources/images/plate1_E01_T0001F001L01A03Z01C03.tif +0 -0
  69. spacr/sequencing.py +234 -422
  70. spacr/settings.py +16 -10
  71. spacr/utils.py +14 -11
  72. {spacr-0.2.68.dist-info → spacr-0.3.0.dist-info}/METADATA +10 -2
  73. {spacr-0.2.68.dist-info → spacr-0.3.0.dist-info}/RECORD +77 -18
  74. {spacr-0.2.68.dist-info → spacr-0.3.0.dist-info}/LICENSE +0 -0
  75. {spacr-0.2.68.dist-info → spacr-0.3.0.dist-info}/WHEEL +0 -0
  76. {spacr-0.2.68.dist-info → spacr-0.3.0.dist-info}/entry_points.txt +0 -0
  77. {spacr-0.2.68.dist-info → spacr-0.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,102 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from segmentation_models_pytorch import MAnet
5
+ from segmentation_models_pytorch.base.modules import Activation
6
+
7
+ __all__ = ["MEDIARFormer"]
8
+
9
+
10
class MEDIARFormer(MAnet):
    """MEDIAR-Former: an MAnet encoder/decoder with two custom output heads.

    The inherited MAnet segmentation head is discarded. Every ``nn.ReLU``
    inside the encoder and the decoder is swapped for ``nn.Mish(inplace=True)``.
    Two ``DeepSegmentationHead`` modules are attached to the last decoder
    stage: one emits a single cell-probability channel and the other emits two
    gradient-flow channels.
    """

    def __init__(
        self,
        encoder_name="mit_b5",
        encoder_weights="imagenet",
        decoder_channels=(1024, 512, 256, 128, 64),
        decoder_pab_channels=256,
        in_channels=3,
        classes=3,
    ):
        """Build the MAnet backbone, then swap activations and output heads.

        Args:
            encoder_name: Encoder identifier forwarded to MAnet.
            encoder_weights: Pre-trained weights identifier forwarded to MAnet.
            decoder_channels: Decoder stage widths; the last entry is the
                input width of both custom heads.
            decoder_pab_channels: Pyramid attention block width forwarded to MAnet.
            in_channels: Number of input image channels.
            classes: Forwarded to MAnet. Its segmentation head is discarded
                below, so this presumably does not affect ``forward`` output.
        """
        super().__init__(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            decoder_channels=decoder_channels,
            decoder_pab_channels=decoder_pab_channels,
            in_channels=in_channels,
            classes=classes,
        )

        # The inherited segmentation head is unused; the two heads below replace it.
        self.segmentation_head = None

        # ReLU -> Mish, with one fresh Mish instance per sub-network.
        for subnet in (self.encoder, self.decoder):
            _convert_activations(subnet, nn.ReLU, nn.Mish(inplace=True))

        head_width = decoder_channels[-1]
        self.cellprob_head = DeepSegmentationHead(in_channels=head_width, out_channels=1)
        self.gradflow_head = DeepSegmentationHead(in_channels=head_width, out_channels=2)

    def forward(self, x):
        """Run encoder, decoder and both heads.

        Returns a tensor with 3 channels on dim 1: the two gradient-flow
        channels first, then the cell-probability channel.
        """
        # Input validation inherited from the MAnet base class.
        self.check_input_shape(x)

        decoded = self.decoder(*self.encoder(x))

        cellprob = self.cellprob_head(decoded)
        gradflow = self.gradflow_head(decoded)

        return torch.cat((gradflow, cellprob), dim=1)
64
+
65
+
66
class DeepSegmentationHead(nn.Sequential):
    """Two-convolution head that maps decoder features to ``out_channels`` maps.

    Sequential layer indices are part of the ``state_dict`` key names, so the
    order below must stay fixed:

        0: Conv2d(in_channels -> in_channels // 2)
        1: Mish
        2: BatchNorm2d(in_channels // 2)
        3: Conv2d(in_channels // 2 -> out_channels)
        4: bilinear upsampling by ``upsampling`` (Identity when ``upsampling <= 1``)
        5: ``Activation(activation)`` (Identity when ``activation`` is falsy)

    Args:
        in_channels: Channels of the incoming feature map; must be >= 2 so
            that the hidden layer is non-empty.
        out_channels: Channels produced by the head.
        kernel_size: Kernel size of both convolutions, padded by
            ``kernel_size // 2`` (keeps spatial size for odd kernels).
        activation: Optional activation spec passed to ``Activation``.
        upsampling: Integer/float scale factor; values > 1 enable upsampling.

    Raises:
        ValueError: If ``in_channels < 2``.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1
    ):
        hidden_channels = in_channels // 2
        if hidden_channels < 1:
            # Otherwise a zero-channel conv is built silently and the head
            # outputs nothing but the final conv's bias.
            raise ValueError(
                f"DeepSegmentationHead needs in_channels >= 2, got {in_channels}"
            )
        padding = kernel_size // 2

        layers = [
            nn.Conv2d(
                in_channels,
                hidden_channels,
                kernel_size=kernel_size,
                padding=padding,
            ),
            nn.Mish(inplace=True),
            nn.BatchNorm2d(hidden_channels),
            nn.Conv2d(
                hidden_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=padding,
            ),
            # nn.UpsamplingBilinear2d is deprecated; this is its documented
            # equivalent (bilinear, align_corners=True) and has no parameters,
            # so state_dict layout is unchanged.
            nn.Upsample(scale_factor=upsampling, mode="bilinear", align_corners=True)
            if upsampling > 1
            else nn.Identity(),
            Activation(activation) if activation else nn.Identity(),
        ]
        super().__init__(*layers)
94
+
95
+
96
+ def _convert_activations(module, from_activation, to_activation):
97
+ """Recursively convert activation functions in a module"""
98
+ for name, child in module.named_children():
99
+ if isinstance(child, from_activation):
100
+ setattr(module, name, to_activation)
101
+ else:
102
+ _convert_activations(child, from_activation, to_activation)
@@ -0,0 +1 @@
1
+ from .MEDIARFormer import *
@@ -0,0 +1,70 @@
1
+ import torch
2
+ import numpy as np
3
+ import os, json, random
4
+ from pprint import pprint
5
+
6
+ __all__ = ["ConfLoader", "directory_setter", "random_seeder", "pprint_config"]
7
+
8
+
9
class _MissingOptionError(AttributeError, KeyError):
    """Raised when a missing config key is read as an attribute.

    Being an AttributeError lets ``hasattr``, ``getattr(obj, name, default)``
    and ``copy.deepcopy`` treat the name as absent. Being a KeyError keeps
    existing ``except KeyError`` handlers working.
    """


class ConfLoader:
    """
    Load a JSON config file whose objects become attribute-accessible dicts.

    ``ConfLoader(conf_name).opt`` holds the loaded configuration. Every JSON
    object, nested ones included, is a ``DictWithAttributeAccess``, so
    ``opt.key`` and ``opt["key"]`` are interchangeable.
    """

    class DictWithAttributeAccess(dict):
        """
        dict subclass whose items can also be read and written as attributes,
        e.g. ``opt.key`` instead of ``opt["key"]``.
        """

        def __getattr__(self, key):
            # Only reached when normal attribute lookup fails (so dict methods
            # are unaffected). A missing key must surface as AttributeError;
            # otherwise probes such as copy.deepcopy's __deepcopy__ lookup or
            # hasattr() crash with a bare KeyError.
            try:
                return self[key]
            except KeyError:
                raise _MissingOptionError(key) from None

        def __setattr__(self, key, value):
            # Attribute assignment stores a dict item, not an instance attribute.
            self[key] = value

    def __init__(self, conf_name):
        """Read ``conf_name`` (a path to a JSON file) into ``self.opt``."""
        self.conf_name = conf_name
        self.opt = self.__get_opt()

    def __load_conf(self):
        # JSON text is UTF-8 by specification; don't depend on the locale default.
        with open(self.conf_name, "r", encoding="utf-8") as conf:
            # object_hook is applied to every decoded JSON object, nested ones too.
            opt = json.load(conf, object_hook=self.DictWithAttributeAccess)
        return opt

    def __get_opt(self):
        opt = self.__load_conf()
        # Wrap the top-level value in a fresh DictWithAttributeAccess
        # (a shallow copy when it already is one).
        opt = self.DictWithAttributeAccess(opt)

        return opt
43
+
44
+
45
def directory_setter(path="./results", make_dir=False):
    """
    Ensure ``path`` is an existing directory, optionally creating it.

    Args:
        path: Directory to check.
        make_dir: If True and ``path`` does not exist, create it along with
            any missing parent directories.

    Raises:
        NotADirectoryError: If ``path`` is not a directory after the optional
            creation step (missing with ``make_dir=False``, or an existing file).
    """
    if make_dir and not os.path.exists(path):
        # exist_ok tolerates another process creating the same directory
        # between the exists() check and makedirs().
        os.makedirs(path, exist_ok=True)
        print("directory %s is created" % path)

    if not os.path.isdir(path):
        raise NotADirectoryError(
            "%s is not valid. set make_dir=True to make dir." % path
        )
57
+
58
+
59
def random_seeder(seed):
    """Seed the PyTorch, NumPy and Python RNGs and force deterministic cuDNN.

    Args:
        seed: Seed value handed to every generator.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
66
+
67
def pprint_config(opt):
    """Print ``opt`` with ``pprint`` (compact mode) between '=' banner lines."""
    rule = "=" * 50
    print(f"\n{rule} Configuration {rule}")
    pprint(opt, compact=True)
    print("=" * 115, end="\n\n")
Binary file
Binary file
Binary file