neuro-sam 0.1.7__tar.gz → 0.1.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {neuro_sam-0.1.7/src/neuro_sam.egg-info → neuro_sam-0.1.8}/PKG-INFO +5 -1
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/pyproject.toml +7 -2
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/path_tracing_module.py +4 -4
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/punet_widget.py +3 -1
- neuro_sam-0.1.8/src/neuro_sam/training/__init__.py +1 -0
- neuro_sam-0.1.8/src/neuro_sam/training/train_dendrites.py +226 -0
- neuro_sam-0.1.8/src/neuro_sam/training/utils/__init__.py +1 -0
- neuro_sam-0.1.8/src/neuro_sam/training/utils/prompt_generation_dendrites.py +78 -0
- neuro_sam-0.1.8/src/neuro_sam/training/utils/stream_dendrites.py +299 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8/src/neuro_sam.egg-info}/PKG-INFO +5 -1
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam.egg-info/SOURCES.txt +5 -1
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam.egg-info/entry_points.txt +1 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam.egg-info/requires.txt +4 -0
- neuro_sam-0.1.7/src/neuro_sam/punet/__init__.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/LICENSE +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/README.md +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/setup.cfg +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/__init__.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/__init__.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/algorithm/__init__.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/algorithm/astar.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/algorithm/waypointastar.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/algorithm/waypointastar_speedup.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/connected_componen.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/cost/__init__.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/cost/cost.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/cost/reciprocal.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/cost/reciprocal_transonic.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/heuristic/__init__.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/heuristic/euclidean.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/heuristic/heuristic.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/image/__init__.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/image/stats.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/input/__init__.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/input/inputs.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/node/__init__.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/node/bidirectional_node.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/node/node.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/visualization/__init__.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/visualization/flythrough.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/visualization/flythrough_all.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/visualization/tube_data.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/visualization/tube_flythrough.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/anisotropic_scaling.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/color_utils.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/contrasting_color_system.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/main_widget.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/segmentation_model.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/segmentation_module.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/visualization_module.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/plugin.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/punet/deepd3_model.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/punet/prob_unet_deepd3.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/punet/prob_unet_with_tversky.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/punet/punet_inference.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/punet/run_inference.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/punet/unet_blocks.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/punet/utils.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/utils.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam.egg-info/dependency_links.txt +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam.egg-info/top_level.txt +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/__init__.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/automatic_mask_generator.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/benchmark.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/build_sam.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/configs/sam2/sam2_hiera_b+.yaml +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/configs/sam2/sam2_hiera_l.yaml +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/configs/sam2/sam2_hiera_s.yaml +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/configs/sam2/sam2_hiera_t.yaml +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/configs/sam2.1/sam2.1_hiera_l.yaml +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/configs/sam2.1/sam2.1_hiera_s.yaml +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/configs/sam2.1/sam2.1_hiera_t.yaml +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/configs/train.yaml +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/modeling/__init__.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/modeling/backbones/__init__.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/modeling/backbones/hieradet.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/modeling/backbones/image_encoder.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/modeling/backbones/utils.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/modeling/memory_attention.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/modeling/memory_encoder.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/modeling/position_encoding.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/modeling/sam/__init__.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/modeling/sam/mask_decoder.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/modeling/sam/prompt_encoder.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/modeling/sam/transformer.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/modeling/sam2_base.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/modeling/sam2_utils.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/sam2.1_hiera_b+.yaml +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/sam2.1_hiera_l.yaml +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/sam2.1_hiera_s.yaml +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/sam2.1_hiera_t.yaml +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/sam2_hiera_b+.yaml +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/sam2_hiera_l.yaml +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/sam2_hiera_s.yaml +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/sam2_hiera_t.yaml +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/sam2_image_predictor.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/sam2_video_predictor.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/sam2_video_predictor_legacy.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/utils/__init__.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/utils/amg.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/utils/misc.py +0 -0
- {neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/sam2/utils/transforms.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: neuro-sam
|
|
3
|
-
Version: 0.1.7
|
|
3
|
+
Version: 0.1.8
|
|
4
4
|
Summary: Neuro-SAM: Foundation Models for Dendrite and Dendritic Spine Segmentation
|
|
5
5
|
Author-email: Nipun Arora <nipunarora8@yahoo.com>
|
|
6
6
|
License: MIT License
|
|
@@ -55,6 +55,10 @@ Requires-Dist: PyQt5
|
|
|
55
55
|
Requires-Dist: opencv-python-headless
|
|
56
56
|
Requires-Dist: matplotlib
|
|
57
57
|
Requires-Dist: requests
|
|
58
|
+
Requires-Dist: flammkuchen
|
|
59
|
+
Requires-Dist: albumentations
|
|
60
|
+
Requires-Dist: wandb
|
|
61
|
+
Requires-Dist: tensorflow
|
|
58
62
|
Dynamic: license-file
|
|
59
63
|
|
|
60
64
|
<div align="center">
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "neuro-sam"
|
|
7
|
-
version = "0.1.7"
|
|
7
|
+
version = "0.1.8"
|
|
8
8
|
description = "Neuro-SAM: Foundation Models for Dendrite and Dendritic Spine Segmentation"
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
authors = [
|
|
@@ -38,7 +38,11 @@ dependencies = [
|
|
|
38
38
|
"PyQt5",
|
|
39
39
|
"opencv-python-headless",
|
|
40
40
|
"matplotlib",
|
|
41
|
-
"requests"
|
|
41
|
+
"requests",
|
|
42
|
+
"flammkuchen",
|
|
43
|
+
"albumentations",
|
|
44
|
+
"wandb",
|
|
45
|
+
"tensorflow"
|
|
42
46
|
]
|
|
43
47
|
requires-python = ">=3.10"
|
|
44
48
|
|
|
@@ -57,6 +61,7 @@ include = ["neuro_sam*", "sam2*"]
|
|
|
57
61
|
[project.scripts]
|
|
58
62
|
neuro-sam = "neuro_sam.plugin:main"
|
|
59
63
|
neuro-sam-download = "neuro_sam.utils:download_all_models"
|
|
64
|
+
neuro-sam-train = "neuro_sam.training.train_dendrites:main"
|
|
60
65
|
|
|
61
66
|
[project.entry-points."napari.manifest"]
|
|
62
67
|
neuro-sam = "neuro_sam:napari.yaml"
|
|
@@ -509,10 +509,10 @@ class PathTracingWidget(QWidget):
|
|
|
509
509
|
self.scaler.original_spacing_xyz[0] / x_nm # X
|
|
510
510
|
])
|
|
511
511
|
|
|
512
|
-
self.scaling_status.setText(
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
)
|
|
512
|
+
# self.scaling_status.setText(
|
|
513
|
+
# f"Pending: X={x_nm:.1f}, Y={y_nm:.1f}, Z={z_nm:.1f} nm\n"
|
|
514
|
+
# f"Scale factors (Z,Y,X): {temp_scale_factors[0]:.3f}, {temp_scale_factors[1]:.3f}, {temp_scale_factors[2]:.3f}"
|
|
515
|
+
# )
|
|
516
516
|
|
|
517
517
|
def _apply_scaling(self):
|
|
518
518
|
"""Apply current scaling settings"""
|
|
@@ -21,6 +21,8 @@ from neuro_sam.punet.punet_inference import run_inference_volume
|
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
|
|
24
|
+
from neuro_sam.utils import get_weights_path
|
|
25
|
+
|
|
24
26
|
class PunetSpineSegmentationWidget(QWidget):
|
|
25
27
|
"""
|
|
26
28
|
Widget for spine segmentation using Probabilistic U-Net.
|
|
@@ -36,7 +38,7 @@ class PunetSpineSegmentationWidget(QWidget):
|
|
|
36
38
|
|
|
37
39
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
38
40
|
self.model = None
|
|
39
|
-
self.model_path = "
|
|
41
|
+
self.model_path = get_weights_path("punet_best.pth") # Auto-download default weights
|
|
40
42
|
|
|
41
43
|
# Connect custom progress signal
|
|
42
44
|
self.progress_signal.connect(self._on_worker_progress)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Init file for training module
|
|
@@ -0,0 +1,226 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
import os
|
|
4
|
+
import wandb
|
|
5
|
+
from torch.onnx.symbolic_opset11 import hstack
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
from .utils.stream_dendrites import DataGeneratorStream
|
|
8
|
+
from sam2.build_sam import build_sam2
|
|
9
|
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
|
10
|
+
import time
|
|
11
|
+
import argparse
|
|
12
|
+
from neuro_sam.utils import get_weights_path
|
|
13
|
+
|
|
14
|
+
def main():
    """Command-line entry point that fine-tunes SAM2 for dendrite segmentation.

    Parses the CLI arguments, builds the training/validation data streams from
    ``DeepD3_Training.d3set`` / ``DeepD3_Validation.d3set`` inside
    ``--data_dir``, loads the requested SAM2 checkpoint and then runs a fixed
    number of optimisation steps. Every 500 steps a short validation pass is
    run and the model weights are written to ``results/samv2_<model>_<time>``.

    Raises:
        FileNotFoundError: If the training or the validation dataset file is
            missing from ``--data_dir``.
    """
    parser = argparse.ArgumentParser(description="Train Neuro-SAM Dendrite Segmenter")
    parser.add_argument("--ppn", type=int, required=True, help="Positive Points Number")
    parser.add_argument("--pnn", type=int, required=True, help="Negative Points Number")
    parser.add_argument("--model_name", type=str, required=True, choices=['small', 'base_plus', 'large', 'tiny'], help="SAM2 Model Size")
    parser.add_argument("--batch_size", type=int, required=True, help="Batch Size")
    parser.add_argument("--logger", type=str, default="False", help="Use WandB Logger (True/False)")
    parser.add_argument("--data_dir", type=str, default="./data", help="Directory containing .d3set data files")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # Mixed precision (autocast + gradient scaling) is only meaningful on CUDA.
    use_amp = device.type == "cuda"

    TRAINING_DATA_PATH = os.path.join(args.data_dir, "DeepD3_Training.d3set")
    VALIDATION_DATA_PATH = os.path.join(args.data_dir, "DeepD3_Validation.d3set")

    if not os.path.exists(TRAINING_DATA_PATH):
        raise FileNotFoundError(f"Training data not found at {TRAINING_DATA_PATH}")
    # Fail early with a clear message instead of deep inside the data loader.
    if not os.path.exists(VALIDATION_DATA_PATH):
        raise FileNotFoundError(f"Validation data not found at {VALIDATION_DATA_PATH}")

    positive_points = args.ppn
    negative_points = args.pnn
    batch_size = args.batch_size
    logger = (args.logger.lower() == "true")

    total_steps = 100000      # number of optimisation steps
    validation_batches = 20   # batches drawn per validation pass
    validation_interval = 500  # validate + checkpoint every N steps

    print("Initializing Data Generator...")
    dg_training = DataGeneratorStream(
        TRAINING_DATA_PATH,
        batch_size=batch_size,  # data processed at once, depends on your GPU
        target_resolution=0.094,  # fixed to 94 nm, can be None for mixed resolution training
        min_content=100,
        resizing_size=128,
        positive_points=positive_points,
        negative_points=negative_points,
    )

    dg_validation = DataGeneratorStream(
        VALIDATION_DATA_PATH,
        batch_size=batch_size,
        target_resolution=0.094,
        min_content=100,
        augment=False,
        shuffle=False,
        resizing_size=128,
        positive_points=positive_points,
        negative_points=negative_points,
    )

    # Map the CLI model size to (checkpoint file, SAM2 config file).
    model_name_map = {
        'large': ("sam2.1_hiera_large.pt", "sam2.1_hiera_l.yaml"),
        'base_plus': ("sam2.1_hiera_base_plus.pt", "sam2.1_hiera_b+.yaml"),
        'small': ("sam2.1_hiera_small.pt", "sam2.1_hiera_s.yaml"),
        'tiny': ("sam2.1_hiera_tiny.pt", "sam2.1_hiera_t.yaml"),
    }
    ckpt_name, model_cfg = model_name_map[args.model_name]

    # Try to resolve the weights via the package helper; fall back to a file
    # of the same name in the working directory.
    try:
        sam2_checkpoint = get_weights_path(ckpt_name)
    except Exception:
        sam2_checkpoint = ckpt_name
        print(f"Warning: Could not resolve weight path for {ckpt_name}, assuming local file.")

    print(f"Loading SAM2 model: {args.model_name} from {sam2_checkpoint}")
    sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
    predictor = SAM2ImagePredictor(sam2_model)

    # Put all three sub-networks into training mode.
    predictor.model.sam_mask_decoder.train(True)
    predictor.model.sam_prompt_encoder.train(True)
    # For the image encoder to actually receive gradients, "no_grad" has to be
    # removed from the SAM2 code paths that run it.
    predictor.model.image_encoder.train(True)
    optimizer = torch.optim.AdamW(params=predictor.model.parameters(), lr=1e-5, weight_decay=4e-5)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    time_str = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    ckpt_path = f'results/samv2_{args.model_name}_{time_str}'
    os.makedirs(ckpt_path, exist_ok=True)

    if logger:
        wandb.init(
            # set the wandb project where this run will be logged
            project="N-SAMv2",
            # track hyperparameters and run metadata
            config={
                "architecture": "SAMv2",
                "dataset": "DeepD3",
                "model": args.model_name,
                "epochs": total_steps,
                "ckpt_path": ckpt_path,
                "image_size": (1, 128, 128),
                "min_content": 100,
                "positive_points": positive_points,
                "negative_points": negative_points,
                "batch_size": batch_size,
                "prompt_seed": 42,
            },
        )

    def compute_batch_metrics(active_predictor, image, mask, input_point, input_label):
        """Run one forward pass and return ``(loss, per-sample IoU, mean Dice)``."""
        active_predictor.set_image_batch(image)  # apply SAM image encoder to the images

        # Prompt encoding
        _, unnorm_coords, labels, _ = active_predictor._prep_prompts(
            input_point, input_label, box=None, mask_logits=None, normalize_coords=True
        )
        sparse_embeddings, dense_embeddings = active_predictor.model.sam_prompt_encoder(
            points=(unnorm_coords, labels), boxes=None, masks=None,
        )

        # Mask decoder
        # NOTE(review): this takes only the last image's high-res features
        # (feat_level[-1]) and, presumably, relies on broadcasting across the
        # batch -- confirm this is intended for batched training.
        high_res_features = [
            feat_level[-1].unsqueeze(0)
            for feat_level in active_predictor._features["high_res_feats"]
        ]
        low_res_masks, prd_scores, _, _ = active_predictor.model.sam_mask_decoder(
            image_embeddings=active_predictor._features["image_embed"],
            image_pe=active_predictor.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=False,
            high_res_features=high_res_features,
        )
        # Upscale the masks to the original image resolution
        prd_masks = active_predictor._transforms.postprocess_masks(
            low_res_masks, active_predictor._orig_hw[-1]
        )

        # Segmentation loss (binary cross entropy on the first predicted mask)
        gt_mask = torch.tensor(mask.astype(np.float32)).to(device)
        prd_mask = torch.sigmoid(prd_masks[:, 0])  # logits -> probabilities
        seg_loss = (
            -gt_mask * torch.log(prd_mask + 0.00001)
            - (1 - gt_mask) * torch.log((1 - prd_mask) + 0.00001)
        ).mean()

        # Score loss: the predicted score should match the actual IoU
        binary_pred = prd_mask > 0.5
        inter = (gt_mask * binary_pred).sum(1).sum(1)
        total_sum = gt_mask.sum(1).sum(1) + binary_pred.sum(1).sum(1)
        iou = inter / (total_sum - inter)
        dice = ((2 * inter) / total_sum).mean()
        score_loss = torch.abs(prd_scores[:, 0] - iou).mean()

        loss = seg_loss + score_loss * 0.05  # mix losses
        return loss, iou, dice

    def perform_validation(val_predictor):
        """Average IoU, Dice and loss over a few random validation batches."""
        print('Performing Validation')
        batch_ious, batch_dices, batch_losses = [], [], []
        with torch.no_grad():
            for _ in range(validation_batches):
                try:
                    n = np.random.randint(len(dg_validation))
                    image, mask, input_point, input_label = dg_validation[n]
                except Exception:
                    print('Error in validation batch generation')
                    continue
                if mask.shape[0] == 0:
                    continue  # ignore empty batches

                loss, iou, dice = compute_batch_metrics(
                    val_predictor, image, mask, input_point, input_label
                )
                batch_ious.append(iou.cpu().detach().numpy())
                batch_dices.append(dice.cpu().detach().numpy())
                batch_losses.append(loss.cpu().detach().numpy())

        if not batch_losses:
            # Every batch failed or was empty: report NaN explicitly instead
            # of averaging an empty array.
            return float("nan"), float("nan"), float("nan")
        return np.array(batch_ious).mean(), np.array(batch_dices).mean(), np.array(batch_losses).mean()

    # Running IoU average. Initialised before the loop so that a skipped first
    # batch cannot leave it undefined for later steps.
    mean_iou = 0
    for itr in range(total_steps):
        # NOTE(review): the nesting of the loop body under autocast was
        # reconstructed from a flattened listing -- confirm against the source.
        with torch.cuda.amp.autocast(enabled=use_amp):  # cast to mixed precision
            n = np.random.randint(len(dg_training))
            try:
                image, mask, input_point, input_label = dg_training[n]
            except Exception:
                print('Error in training batch')
                continue
            if mask.shape[0] == 0:
                continue  # ignore empty batches

            loss, iou, dicee = compute_batch_metrics(
                predictor, image, mask, input_point, input_label
            )

            predictor.model.zero_grad()  # empty gradient
            scaler.scale(loss).backward()  # backpropagate
            scaler.step(optimizer)
            scaler.update()  # mixed precision

            # Display results (plain floats, so no live tensors are logged)
            mean_iou = mean_iou * 0.99 + 0.01 * np.mean(iou.cpu().detach().numpy())
            loss_value = loss.item()
            dice_value = dicee.item()
            if logger:
                wandb.log({'step': itr, 'loss': loss_value, 'iou': mean_iou, 'dice_score': dice_value})
            print(f'step: {itr}, loss: {loss_value}, iou: {mean_iou}, dice_score: {dice_value}')

            if itr % validation_interval == 0:
                val_iou, val_dice, val_loss = perform_validation(predictor)
                if logger:
                    wandb.log({'step': itr, 'val_loss': val_loss, 'val_iou': val_iou, 'val_dice_score': val_dice})
                print(f'step: {itr}, val_loss: {val_loss}, val_iou: {val_iou}, val_dice_score: {val_dice}')

                torch.save(predictor.model.state_dict(), f"{ckpt_path}/model_{itr}.torch")  # save model
    if logger:
        wandb.finish()
|
|
224
|
+
|
|
225
|
+
if __name__ == "__main__":
|
|
226
|
+
main()
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Init file for training utils
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from scipy import ndimage
|
|
3
|
+
import cv2, random
|
|
4
|
+
from collections import *
|
|
5
|
+
from itertools import *
|
|
6
|
+
from functools import *
|
|
7
|
+
|
|
8
|
+
class PromptGeneration:
    """Sample positive / negative point prompts from a segmentation mask."""

    def __init__(self, random_seed=0, neg_range=(3, 20), min_content = 20):
        """Configure the prompt sampler.

        Args:
            random_seed: Seed applied to both ``random`` and ``numpy.random``.
                A falsy value (``0``, the default) leaves the global random
                state untouched.
            neg_range: ``(inner, outer)`` dilation radii, in pixels, that
                bound the band around the mask from which negative points
                are drawn.
            min_content: Minimum pixel count a connected component needs in
                order to be kept by :meth:`get_labelmap`.
        """
        if random_seed:
            random.seed(random_seed)
            np.random.seed(random_seed)

        self.neg_range = neg_range
        self.min_area = min_content

    def get_labelmap(self, label):
        """Label the connected components that have at least ``min_area`` pixels.

        Args:
            label: 2-D mask; non-zero pixels are foreground.

        Returns:
            ``(labelmaps, connected_num)`` as returned by
            ``scipy.ndimage.label`` (8-connectivity) after the components
            that are too small have been removed.
        """
        structure = ndimage.generate_binary_structure(2, 2)
        labelmaps, connected_num = ndimage.label(label, structure=structure)

        # Component sizes in a single pass instead of one full-image
        # comparison per component.
        sizes = np.bincount(labelmaps.ravel(), minlength=connected_num + 1)
        keep_lookup = sizes >= self.min_area
        keep_lookup[0] = False  # index 0 is the background, never a component
        large_enough = keep_lookup[labelmaps]

        # Relabel so the surviving components are numbered consecutively.
        labelmaps, connected_num = ndimage.label(large_enough, structure=structure)
        return labelmaps, connected_num

    def search_negative_region_numpy(self, labelmap):
        """Return the band around ``labelmap`` used for negative prompts.

        The band is the difference between the surroundings found with the
        outer and with the inner radius of ``self.neg_range``. Callers in this
        class select the pixels whose value is exactly ``1``.

        Args:
            labelmap: Label image accepted by ``cv2.dilate``.
                NOTE(review): callers in this class pass a 0/1 mask -- confirm
                the expected dtype (presumably ``uint8``).
        """
        inner_range, outer_range = self.neg_range

        def search(neg_range):
            # Square structuring element of "radius" neg_range.
            kernel = np.ones((neg_range * 2 + 1, neg_range * 2 + 1), np.uint8)
            # Dilation propagates the largest label in each neighbourhood ...
            negative_region = cv2.dilate(labelmap, kernel, iterations=1)
            # ... and dilating the reversed labels propagates the smallest
            # one, so a non-zero sum marks pixels reached by different labels.
            mx = labelmap.max() + 1
            labelmap_r = (mx - labelmap) * np.minimum(1, labelmap)
            r = cv2.dilate(labelmap_r, kernel, iterations=1)
            negative_region_r = (r.astype(np.int32) - mx) * np.minimum(1, r)
            diff = negative_region.astype(np.int32) + negative_region_r
            overlap = np.minimum(1, np.abs(diff).astype(np.uint8))
            # Dilated area minus contested pixels minus the objects themselves.
            return negative_region - overlap - labelmap

        return search(outer_range) - search(inner_range)

    def get_prompt_points(self, label_mask, ppp, ppn):
        """Pick positive and negative prompt points for a mask.

        One point per connected component is drawn first (components visited
        in random order). Any remaining budget is filled by sampling without
        replacement from the whole mask (positive) or from the band around
        the whole mask (negative).

        Args:
            label_mask: 2-D mask; values ``>= 1`` are foreground.
            ppp: Number of positive points requested.
            ppn: Number of negative points requested.

        Returns:
            ``(coord_positive, coord_negative)``: two lists of
            ``[column, row]`` pairs.

        Raises:
            ValueError: If the remaining budget exceeds the number of
                candidate pixels (raised by ``random.sample``).
        """
        label_mask_cp = np.copy(label_mask)
        label_mask_cp[label_mask_cp >= 1] = 1
        labelmaps, connected_num = self.get_labelmap(label_mask_cp)

        coord_positive, coord_negative = [], []

        connected_components = list(range(1, connected_num+1))
        random.shuffle(connected_components)

        # One positive point per component until the budget is used up.
        for i in connected_components:
            if not ppp:
                break  # nothing left to draw; skip the remaining components
            # np.argwhere yields (row, col); points are stored as [col, row].
            candidates = [[y, x] for x, y in np.argwhere(labelmaps == i)]
            coord_positive.append(random.choice(candidates))
            ppp -= 1

        # Shuffled even when no negatives are needed, so the sequence of
        # random draws is independent of the early exit below.
        random.shuffle(connected_components)
        for i in connected_components:
            if not ppn:
                break  # avoid the dilations once the budget is exhausted
            cc = (labelmaps == i).astype(np.uint8)
            negative_region = self.search_negative_region_numpy(cc)
            # Never place a negative point on any foreground pixel.
            negative_region = negative_region * (1 - label_mask_cp)
            candidates = [[y, x] for x, y in np.argwhere(negative_region == 1)]
            if not candidates:
                # No free pixel around this component: leave the budget
                # untouched and let the global sampling below fill it.
                continue
            coord_negative.append(random.choice(candidates))
            ppn -= 1

        if ppp:
            coord_positive += random.sample([[y, x] for x, y in np.argwhere(label_mask_cp == 1)], ppp)
        if ppn:
            # Only computed when it is actually needed.
            negative_region = self.search_negative_region_numpy(label_mask_cp)
            coord_negative += random.sample([[y, x] for x, y in np.argwhere(negative_region == 1)], ppn)
        return coord_positive, coord_negative
|
|
@@ -0,0 +1,299 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import flammkuchen as fl
|
|
3
|
+
from tensorflow.keras.utils import Sequence
|
|
4
|
+
from .prompt_generation_dendrites import PromptGeneration
|
|
5
|
+
import cv2
|
|
6
|
+
import albumentations as A
|
|
7
|
+
import random
|
|
8
|
+
|
|
9
|
+
class DataGeneratorStream(Sequence):
|
|
10
|
+
def __init__(self, fn, batch_size, samples_per_epoch=50000, size=(1, 128, 128), target_resolution=None, augment=True,
|
|
11
|
+
shuffle=True, seed=42, normalize=[-1, 1], resizing_size=1024, min_content=0., positive_points=1, negative_points=1
|
|
12
|
+
):
|
|
13
|
+
"""Data Generator that streams data dynamically for training DeepD3.
|
|
14
|
+
|
|
15
|
+
Args:
|
|
16
|
+
fn (str): The path to the training data file
|
|
17
|
+
batch_size (int): Batch size for training deep neural networks
|
|
18
|
+
samples_per_epoch (int, optional): Samples used in each epoch. Defaults to 50000.
|
|
19
|
+
size (tuple, optional): Shape of a single sample. Defaults to (1, 128, 128).
|
|
20
|
+
target_resolution (float, optional): Target resolution in microns. Defaults to None.
|
|
21
|
+
augment (bool, optional): Enables augmenting the data. Defaults to True.
|
|
22
|
+
shuffle (bool, optional): Enabled shuffling the data. Defaults to True.
|
|
23
|
+
seed (int, optional): Creates pseudorandom numbers for shuffling. Defaults to 42.
|
|
24
|
+
normalize (list, optional): Values range when normalizing data. Defaults to [-1, 1].
|
|
25
|
+
min_content (float, optional): Minimum content in image (annotated dendrite or spine), not considered if 0. Defaults to 0.
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
# Save settings
|
|
29
|
+
self.batch_size = batch_size
|
|
30
|
+
self.augment = augment
|
|
31
|
+
self.fn = fn
|
|
32
|
+
self.shuffle = shuffle
|
|
33
|
+
self.aug = self._get_augmenter()
|
|
34
|
+
self.add_gaus_noise = self._add_gaus_noise()
|
|
35
|
+
self.seed = seed
|
|
36
|
+
self.normalize = normalize
|
|
37
|
+
self.samples_per_epoch = samples_per_epoch
|
|
38
|
+
self.size = size
|
|
39
|
+
self.target_resolution = target_resolution
|
|
40
|
+
self.min_content = min_content
|
|
41
|
+
self.resizing_size = resizing_size
|
|
42
|
+
self.ppn = positive_points
|
|
43
|
+
self.pnn = negative_points
|
|
44
|
+
|
|
45
|
+
self.d = fl.load(self.fn)
|
|
46
|
+
self.data = self.d['data']
|
|
47
|
+
self.meta = self.d['meta']
|
|
48
|
+
|
|
49
|
+
# Seed randomness
|
|
50
|
+
random.seed(self.seed)
|
|
51
|
+
np.random.seed(self.seed)
|
|
52
|
+
|
|
53
|
+
self.on_epoch_end()
|
|
54
|
+
self.pg = PromptGeneration(random_seed=42, neg_range=(3, 9))
|
|
55
|
+
|
|
56
|
+
def __len__(self):
|
|
57
|
+
"""Denotes the number of batches per epoch"""
|
|
58
|
+
return self.samples_per_epoch // self.batch_size
|
|
59
|
+
|
|
60
|
+
def __getitem__(self, index):
|
|
61
|
+
"""Generate one batch of data
|
|
62
|
+
|
|
63
|
+
Parameters
|
|
64
|
+
----------
|
|
65
|
+
index : int
|
|
66
|
+
batch index in image/label id list
|
|
67
|
+
|
|
68
|
+
Returns
|
|
69
|
+
-------
|
|
70
|
+
tuple
|
|
71
|
+
Contains two numpy arrays,
|
|
72
|
+
each of shape (batch_size, height, width, 1).
|
|
73
|
+
"""
|
|
74
|
+
X = []
|
|
75
|
+
Y0 = []
|
|
76
|
+
Y1 = []
|
|
77
|
+
pps = []
|
|
78
|
+
pts = []
|
|
79
|
+
eps = 1e-5
|
|
80
|
+
|
|
81
|
+
if self.shuffle is False:
|
|
82
|
+
np.random.seed(index)
|
|
83
|
+
|
|
84
|
+
# Create all pairs in a given batch
|
|
85
|
+
for i in range(self.batch_size):
|
|
86
|
+
# Retrieve a single sample pair
|
|
87
|
+
image, dendrite, spines = self.getSample()
|
|
88
|
+
image = cv2.resize(image, (self.resizing_size, self.resizing_size), interpolation=cv2.INTER_LINEAR)
|
|
89
|
+
dendrite = cv2.resize(dendrite.astype(np.uint8), (self.resizing_size, self.resizing_size), interpolation=cv2.INTER_LINEAR)
|
|
90
|
+
spines = cv2.resize(spines.astype(np.uint8), (self.resizing_size, self.resizing_size), interpolation=cv2.INTER_LINEAR)
|
|
91
|
+
|
|
92
|
+
# Augmenting the data
|
|
93
|
+
if self.augment:
|
|
94
|
+
augmented = self.aug(image=image,
|
|
95
|
+
mask1=dendrite,
|
|
96
|
+
mask2=spines) #augment image
|
|
97
|
+
|
|
98
|
+
image = augmented['image']
|
|
99
|
+
dendrite = augmented['mask1']
|
|
100
|
+
spines = augmented['mask2']
|
|
101
|
+
try:
|
|
102
|
+
pp, pt = self.extract_points(dendrite)
|
|
103
|
+
|
|
104
|
+
augmented = self.add_gaus_noise(image=image,
|
|
105
|
+
mask1=dendrite,
|
|
106
|
+
mask2=spines) #augment image
|
|
107
|
+
|
|
108
|
+
image = augmented['image']
|
|
109
|
+
dendrite = augmented['mask1']
|
|
110
|
+
spines = augmented['mask2']
|
|
111
|
+
except:
|
|
112
|
+
pass
|
|
113
|
+
|
|
114
|
+
else:
|
|
115
|
+
dendrite = dendrite
|
|
116
|
+
spines = spines
|
|
117
|
+
try:
|
|
118
|
+
pp, pt = self.extract_points(dendrite)
|
|
119
|
+
except:
|
|
120
|
+
return False
|
|
121
|
+
|
|
122
|
+
# Min/max scaling
|
|
123
|
+
image = (image.astype(np.float32) - image.min()) / (image.max() - image.min() + eps)
|
|
124
|
+
# Shifting and scaling
|
|
125
|
+
# image = image * (self.normalize[1]-self.normalize[0]) + self.normalize[0]
|
|
126
|
+
|
|
127
|
+
X.append(cv2.cvtColor(image, cv2.COLOR_GRAY2RGB))
|
|
128
|
+
Y0.append(dendrite.astype(np.float32) / (dendrite.max() + eps))
|
|
129
|
+
Y1.append(spines.astype(np.float32) / (spines.max() + eps)) # to ensure binary targets
|
|
130
|
+
pps.append(pp)
|
|
131
|
+
pts.append(pt)
|
|
132
|
+
|
|
133
|
+
# return np.asarray(X, dtype=np.float32)[..., None], (np.asarray(Y0, dtype=np.float32)[..., None],
|
|
134
|
+
# np.asarray(Y1, dtype=np.float32)[..., None]), (np.asarray(pps), np.asarray(pts))
|
|
135
|
+
return np.asarray(X, dtype=np.float32), np.asarray(Y0, dtype=np.float32), np.asarray(pps), np.asarray(pts)
|
|
136
|
+
|
|
137
|
+
def extract_points(self, dendrite):
    """Build point prompts and matching labels for a dendrite mask.

    The prompt generator ``self.pg`` is asked for positive and negative
    coordinates, receiving ``self.ppn`` and ``self.pnn`` as its two count
    arguments. Both coordinate collections are joined with ``+``
    (positives first) and converted to arrays.

    Args:
        dendrite: mask handed unchanged to ``self.pg.get_prompt_points``.

    Returns:
        tuple: ``(pp, pt)`` where ``pp`` is a float32 array of the joined
        coordinates and ``pt`` is an int32 array containing 1 for each
        positive point followed by 0 for each negative point.
    """
    n_first, n_second = self.ppn, self.pnn
    positives, negatives = self.pg.get_prompt_points(dendrite, n_first, n_second)

    # Positives always precede negatives so labels line up with coordinates.
    joined = positives + negatives
    labels = np.concatenate([
        np.ones(len(positives), dtype=np.int32),
        np.zeros(len(negatives), dtype=np.int32),
    ])
    coords = np.array(joined, dtype=np.float32)
    return coords, labels
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _get_augmenter(self):
    """Create the albumentations pipeline used for data augmentation.

    The pipeline chains brightness/contrast jitter, a rotation limited to
    10 with reflected borders, 90-degree rotations, horizontal and vertical
    flips and a blur, each with its own probability. Two extra targets,
    ``mask1`` and ``mask2``, are registered as masks so they are passed
    through the pipeline together with the image.

    Returns:
        A.Compose: the composed augmentation pipeline.
    """
    transforms = [
        A.RandomBrightnessContrast(p=0.25),
        A.Rotate(limit=10, border_mode=cv2.BORDER_REFLECT, p=0.5),
        A.RandomRotate90(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Blur(p=0.2),
    ]
    mask_targets = {'mask1': 'mask', 'mask2': 'mask'}
    return A.Compose(transforms, p=1, additional_targets=mask_targets)
|
|
162
|
+
|
|
163
|
+
def _add_gaus_noise(self):
    """Create an albumentations pipeline that only applies Gaussian noise.

    ``A.GaussNoise`` is applied with probability 0.5. ``mask1`` and
    ``mask2`` are registered as additional mask targets so the pipeline
    accepts the same keyword arguments as the main augmenter.

    Returns:
        A.Compose: the composed noise pipeline.
    """
    noise_steps = [A.GaussNoise(p=0.5)]
    mask_targets = {'mask1': 'mask', 'mask2': 'mask'}
    return A.Compose(noise_steps, p=1, additional_targets=mask_targets)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def getSample(self, squeeze=True):
    """Keep drawing random samples until one is acceptable.

    A draw is retried when ``_getSample`` returns ``None`` or, if
    ``self.min_content`` is non-zero, when the content check fails.

    Args:
        squeeze (bool, optional): forwarded to ``_getSample``.
            Defaults to True.

    Returns:
        The first accepted value returned by ``_getSample``; elements 1
        and 2 of it are the label arrays used for the content check.
    """
    while True:
        candidate = self._getSample(squeeze)

        # No sample could be produced for this draw: try again.
        if candidate is None:
            continue

        # Content does not matter: accept any successfully drawn sample.
        if self.min_content == 0:
            return candidate

        # Content matters: both label arrays must sum to strictly more
        # than `min_content`.
        if candidate[1].sum() > self.min_content and candidate[2].sum() > self.min_content:
            return candidate
|
|
202
|
+
|
|
203
|
+
def _getSample(self, squeeze=True):
    """Draw one random crop from a randomly chosen stack.

    A stack index is drawn from ``self.meta``; a random x/y offset and a
    random start plane are then drawn so that the crop fits into the stack.
    When ``self.target_resolution`` is set, the crop extent in the stack is
    scaled by ``target_resolution / Resolution_XY`` and each plane is resized
    back to the requested height and width.

    Args:
        squeeze (bool, optional): Squeezes return shape. Defaults to True.

    Returns:
        tuple or None: Tuple of stack (X), dendrite (Y0) and spines (Y1),
        or ``None`` when the chosen stack is narrower or shorter than the
        crop that would have to be cut from it.
    """
    # Promote a 2D (height, width) request to a single-plane 3D request.
    size = (1,) + self.size if len(self.size) == 2 else self.size

    # Sample a random stack
    r_stack = np.random.choice(len(self.meta))
    stack_meta = self.meta.iloc[r_stack]

    target_h = size[1]
    target_w = size[2]

    if self.target_resolution is None:
        # Keep everything as is
        scaling = 1
        h = target_h
        w = target_w
    else:
        # Scaling factor between requested and stored XY resolution
        scaling = self.target_resolution / stack_meta.Resolution_XY

        # Extent that has to be cut from the stack before resizing
        h = round(scaling * target_h)
        w = round(scaling * target_w)

    # Random horizontal offset; give up if the stack is too narrow.
    free_x = stack_meta.Width - w
    if free_x < 0:
        return None
    x = 0 if free_x == 0 else np.random.choice(free_x)

    # Random vertical offset; give up if the stack is too short.
    free_y = stack_meta.Height - h
    if free_y < 0:
        return None
    y = 0 if free_y == 0 else np.random.choice(free_y)

    # Select random plane + range
    z_begin = np.random.choice(stack_meta.Depth - size[0] + 1)
    z_end = z_begin + size[0]

    key = f'x{r_stack}'
    tmp_stack = self.data['stacks'][key][z_begin:z_end, y:y+h, x:x+w]
    tmp_dendrites = self.data['dendrites'][key][z_begin:z_end, y:y+h, x:x+w]
    tmp_spines = self.data['spines'][key][z_begin:z_end, y:y+h, x:x+w]

    if scaling != 1:
        # Data needs to be rescaled plane by plane. cv2.resize expects
        # dsize as (width, height), so the width goes first; passing
        # (height, width) would transpose the shape of non-square crops.
        dsize = (target_w, target_h)
        n_planes = tmp_stack.shape[0]

        return_stack = np.asarray(
            [cv2.resize(tmp_stack[i], dsize) for i in range(n_planes)])
        # Masks are resized as uint8 (for OpenCV) and turned back into bool.
        return_dendrites = np.asarray(
            [cv2.resize(tmp_dendrites[i].astype(np.uint8), dsize).astype(bool)
             for i in range(n_planes)])
        return_spines = np.asarray(
            [cv2.resize(tmp_spines[i].astype(np.uint8), dsize).astype(bool)
             for i in range(n_planes)])
    else:
        return_stack = tmp_stack
        return_dendrites = tmp_dendrites
        return_spines = tmp_spines

    if squeeze:
        return return_stack.squeeze(), return_dendrites.squeeze(), return_spines.squeeze()

    return return_stack, return_dendrites, return_spines
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: neuro-sam
|
|
3
|
-
Version: 0.1.7
|
|
3
|
+
Version: 0.1.8
|
|
4
4
|
Summary: Neuro-SAM: Foundation Models for Dendrite and Dendritic Spine Segmentation
|
|
5
5
|
Author-email: Nipun Arora <nipunarora8@yahoo.com>
|
|
6
6
|
License: MIT License
|
|
@@ -55,6 +55,10 @@ Requires-Dist: PyQt5
|
|
|
55
55
|
Requires-Dist: opencv-python-headless
|
|
56
56
|
Requires-Dist: matplotlib
|
|
57
57
|
Requires-Dist: requests
|
|
58
|
+
Requires-Dist: flammkuchen
|
|
59
|
+
Requires-Dist: albumentations
|
|
60
|
+
Requires-Dist: wandb
|
|
61
|
+
Requires-Dist: tensorflow
|
|
58
62
|
Dynamic: license-file
|
|
59
63
|
|
|
60
64
|
<div align="center">
|
|
@@ -44,7 +44,6 @@ src/neuro_sam/napari_utils/punet_widget.py
|
|
|
44
44
|
src/neuro_sam/napari_utils/segmentation_model.py
|
|
45
45
|
src/neuro_sam/napari_utils/segmentation_module.py
|
|
46
46
|
src/neuro_sam/napari_utils/visualization_module.py
|
|
47
|
-
src/neuro_sam/punet/__init__.py
|
|
48
47
|
src/neuro_sam/punet/deepd3_model.py
|
|
49
48
|
src/neuro_sam/punet/prob_unet_deepd3.py
|
|
50
49
|
src/neuro_sam/punet/prob_unet_with_tversky.py
|
|
@@ -52,6 +51,11 @@ src/neuro_sam/punet/punet_inference.py
|
|
|
52
51
|
src/neuro_sam/punet/run_inference.py
|
|
53
52
|
src/neuro_sam/punet/unet_blocks.py
|
|
54
53
|
src/neuro_sam/punet/utils.py
|
|
54
|
+
src/neuro_sam/training/__init__.py
|
|
55
|
+
src/neuro_sam/training/train_dendrites.py
|
|
56
|
+
src/neuro_sam/training/utils/__init__.py
|
|
57
|
+
src/neuro_sam/training/utils/prompt_generation_dendrites.py
|
|
58
|
+
src/neuro_sam/training/utils/stream_dendrites.py
|
|
55
59
|
src/sam2/__init__.py
|
|
56
60
|
src/sam2/automatic_mask_generator.py
|
|
57
61
|
src/sam2/benchmark.py
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/algorithm/waypointastar.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/cost/reciprocal_transonic.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/node/bidirectional_node.py
RENAMED
|
File without changes
|
|
File without changes
|
{neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/visualization/__init__.py
RENAMED
|
File without changes
|
{neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/visualization/flythrough.py
RENAMED
|
File without changes
|
{neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/visualization/flythrough_all.py
RENAMED
|
File without changes
|
{neuro_sam-0.1.7 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/visualization/tube_data.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|