neuro-sam 0.1.6__tar.gz → 0.1.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {neuro_sam-0.1.6/src/neuro_sam.egg-info → neuro_sam-0.1.8}/PKG-INFO +6 -1
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/pyproject.toml +9 -2
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/path_tracing_module.py +4 -4
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/punet_widget.py +6 -4
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/segmentation_model.py +16 -7
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/segmentation_module.py +1 -4
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/plugin.py +5 -3
- neuro_sam-0.1.8/src/neuro_sam/training/__init__.py +1 -0
- neuro_sam-0.1.8/src/neuro_sam/training/train_dendrites.py +226 -0
- neuro_sam-0.1.8/src/neuro_sam/training/utils/__init__.py +1 -0
- neuro_sam-0.1.8/src/neuro_sam/training/utils/prompt_generation_dendrites.py +78 -0
- neuro_sam-0.1.8/src/neuro_sam/training/utils/stream_dendrites.py +299 -0
- neuro_sam-0.1.8/src/neuro_sam/utils.py +90 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8/src/neuro_sam.egg-info}/PKG-INFO +6 -1
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam.egg-info/SOURCES.txt +6 -1
- neuro_sam-0.1.8/src/neuro_sam.egg-info/entry_points.txt +7 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam.egg-info/requires.txt +5 -0
- neuro_sam-0.1.6/src/neuro_sam/punet/__init__.py +0 -0
- neuro_sam-0.1.6/src/neuro_sam.egg-info/entry_points.txt +0 -5
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/LICENSE +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/README.md +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/setup.cfg +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/__init__.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/__init__.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/algorithm/__init__.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/algorithm/astar.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/algorithm/waypointastar.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/algorithm/waypointastar_speedup.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/connected_componen.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/cost/__init__.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/cost/cost.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/cost/reciprocal.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/cost/reciprocal_transonic.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/heuristic/__init__.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/heuristic/euclidean.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/heuristic/heuristic.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/image/__init__.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/image/stats.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/input/__init__.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/input/inputs.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/node/__init__.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/node/bidirectional_node.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/node/node.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/visualization/__init__.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/visualization/flythrough.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/visualization/flythrough_all.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/visualization/tube_data.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/visualization/tube_flythrough.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/anisotropic_scaling.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/color_utils.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/contrasting_color_system.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/main_widget.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/visualization_module.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/punet/deepd3_model.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/punet/prob_unet_deepd3.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/punet/prob_unet_with_tversky.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/punet/punet_inference.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/punet/run_inference.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/punet/unet_blocks.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/punet/utils.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam.egg-info/dependency_links.txt +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam.egg-info/top_level.txt +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/__init__.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/automatic_mask_generator.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/benchmark.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/build_sam.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/configs/sam2/sam2_hiera_b+.yaml +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/configs/sam2/sam2_hiera_l.yaml +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/configs/sam2/sam2_hiera_s.yaml +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/configs/sam2/sam2_hiera_t.yaml +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/configs/sam2.1/sam2.1_hiera_l.yaml +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/configs/sam2.1/sam2.1_hiera_s.yaml +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/configs/sam2.1/sam2.1_hiera_t.yaml +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/configs/train.yaml +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/__init__.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/backbones/__init__.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/backbones/hieradet.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/backbones/image_encoder.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/backbones/utils.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/memory_attention.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/memory_encoder.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/position_encoding.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/sam/__init__.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/sam/mask_decoder.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/sam/prompt_encoder.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/sam/transformer.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/sam2_base.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/sam2_utils.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/sam2.1_hiera_b+.yaml +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/sam2.1_hiera_l.yaml +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/sam2.1_hiera_s.yaml +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/sam2.1_hiera_t.yaml +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/sam2_hiera_b+.yaml +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/sam2_hiera_l.yaml +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/sam2_hiera_s.yaml +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/sam2_hiera_t.yaml +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/sam2_image_predictor.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/sam2_video_predictor.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/sam2_video_predictor_legacy.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/utils/__init__.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/utils/amg.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/utils/misc.py +0 -0
- {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/utils/transforms.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: neuro-sam
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.8
|
|
4
4
|
Summary: Neuro-SAM: Foundation Models for Dendrite and Dendritic Spine Segmentation
|
|
5
5
|
Author-email: Nipun Arora <nipunarora8@yahoo.com>
|
|
6
6
|
License: MIT License
|
|
@@ -54,6 +54,11 @@ Requires-Dist: numba
|
|
|
54
54
|
Requires-Dist: PyQt5
|
|
55
55
|
Requires-Dist: opencv-python-headless
|
|
56
56
|
Requires-Dist: matplotlib
|
|
57
|
+
Requires-Dist: requests
|
|
58
|
+
Requires-Dist: flammkuchen
|
|
59
|
+
Requires-Dist: albumentations
|
|
60
|
+
Requires-Dist: wandb
|
|
61
|
+
Requires-Dist: tensorflow
|
|
57
62
|
Dynamic: license-file
|
|
58
63
|
|
|
59
64
|
<div align="center">
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "neuro-sam"
|
|
7
|
-
version = "0.1.
|
|
7
|
+
version = "0.1.8"
|
|
8
8
|
description = "Neuro-SAM: Foundation Models for Dendrite and Dendritic Spine Segmentation"
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
authors = [
|
|
@@ -37,7 +37,12 @@ dependencies = [
|
|
|
37
37
|
"numba",
|
|
38
38
|
"PyQt5",
|
|
39
39
|
"opencv-python-headless",
|
|
40
|
-
"matplotlib"
|
|
40
|
+
"matplotlib",
|
|
41
|
+
"requests",
|
|
42
|
+
"flammkuchen",
|
|
43
|
+
"albumentations",
|
|
44
|
+
"wandb",
|
|
45
|
+
"tensorflow"
|
|
41
46
|
]
|
|
42
47
|
requires-python = ">=3.10"
|
|
43
48
|
|
|
@@ -55,6 +60,8 @@ include = ["neuro_sam*", "sam2*"]
|
|
|
55
60
|
|
|
56
61
|
[project.scripts]
|
|
57
62
|
neuro-sam = "neuro_sam.plugin:main"
|
|
63
|
+
neuro-sam-download = "neuro_sam.utils:download_all_models"
|
|
64
|
+
neuro-sam-train = "neuro_sam.training.train_dendrites:main"
|
|
58
65
|
|
|
59
66
|
[project.entry-points."napari.manifest"]
|
|
60
67
|
neuro-sam = "neuro_sam:napari.yaml"
|
|
@@ -509,10 +509,10 @@ class PathTracingWidget(QWidget):
|
|
|
509
509
|
self.scaler.original_spacing_xyz[0] / x_nm # X
|
|
510
510
|
])
|
|
511
511
|
|
|
512
|
-
self.scaling_status.setText(
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
)
|
|
512
|
+
# self.scaling_status.setText(
|
|
513
|
+
# f"Pending: X={x_nm:.1f}, Y={y_nm:.1f}, Z={z_nm:.1f} nm\n"
|
|
514
|
+
# f"Scale factors (Z,Y,X): {temp_scale_factors[0]:.3f}, {temp_scale_factors[1]:.3f}, {temp_scale_factors[2]:.3f}"
|
|
515
|
+
# )
|
|
516
516
|
|
|
517
517
|
def _apply_scaling(self):
|
|
518
518
|
"""Apply current scaling settings"""
|
|
@@ -21,6 +21,8 @@ from neuro_sam.punet.punet_inference import run_inference_volume
|
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
|
|
24
|
+
from neuro_sam.utils import get_weights_path
|
|
25
|
+
|
|
24
26
|
class PunetSpineSegmentationWidget(QWidget):
|
|
25
27
|
"""
|
|
26
28
|
Widget for spine segmentation using Probabilistic U-Net.
|
|
@@ -36,7 +38,7 @@ class PunetSpineSegmentationWidget(QWidget):
|
|
|
36
38
|
|
|
37
39
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
38
40
|
self.model = None
|
|
39
|
-
self.model_path = "
|
|
41
|
+
self.model_path = get_weights_path("punet_best.pth") # Auto-download default weights
|
|
40
42
|
|
|
41
43
|
# Connect custom progress signal
|
|
42
44
|
self.progress_signal.connect(self._on_worker_progress)
|
|
@@ -191,13 +193,13 @@ class PunetSpineSegmentationWidget(QWidget):
|
|
|
191
193
|
def _segmentation_worker(self, vol, params):
|
|
192
194
|
import traceback
|
|
193
195
|
try:
|
|
196
|
+
# Import the refactored inference function from the package
|
|
194
197
|
# Import the refactored inference function from the package
|
|
195
198
|
try:
|
|
196
199
|
from neuro_sam.punet.punet_inference import run_inference_volume
|
|
197
200
|
except ImportError:
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
from neuro_sam.punet_inference import run_inference_volume
|
|
201
|
+
# Fallback should not be needed with proper package execution
|
|
202
|
+
raise ImportError("Could not import run_inference_volume from neuro_sam.punet.punet_inference")
|
|
201
203
|
|
|
202
204
|
yield "Starting inference..."
|
|
203
205
|
|
|
@@ -8,22 +8,33 @@ from scipy.ndimage import label
|
|
|
8
8
|
from matplotlib.path import Path
|
|
9
9
|
|
|
10
10
|
|
|
11
|
+
from neuro_sam.utils import get_weights_path
|
|
12
|
+
|
|
11
13
|
class DendriteSegmenter:
|
|
12
14
|
"""Class for segmenting dendrites from 3D image volumes using SAM2 with overlapping patches"""
|
|
13
15
|
|
|
14
|
-
def __init__(self, model_path=
|
|
16
|
+
def __init__(self, model_path=None, config_path="sam2.1_hiera_s.yaml", weights_path=None, device="cuda"):
|
|
15
17
|
"""
|
|
16
18
|
Initialize the dendrite segmenter with overlapping patches.
|
|
17
19
|
|
|
18
20
|
Args:
|
|
19
|
-
model_path: Path to SAM2 model checkpoint
|
|
21
|
+
model_path: Path to SAM2 model checkpoint (auto-downloaded if None)
|
|
20
22
|
config_path: Path to model configuration
|
|
21
|
-
weights_path: Path to trained weights
|
|
23
|
+
weights_path: Path to trained weights (auto-downloaded if None)
|
|
22
24
|
device: Device to run the model on (cpu or cuda)
|
|
23
25
|
"""
|
|
24
|
-
|
|
26
|
+
if model_path is None:
|
|
27
|
+
self.model_path = get_weights_path("sam2.1_hiera_small.pt")
|
|
28
|
+
else:
|
|
29
|
+
self.model_path = model_path
|
|
30
|
+
|
|
25
31
|
self.config_path = config_path
|
|
26
|
-
|
|
32
|
+
|
|
33
|
+
if weights_path is None:
|
|
34
|
+
self.weights_path = get_weights_path("dendrite_model.torch")
|
|
35
|
+
else:
|
|
36
|
+
self.weights_path = weights_path
|
|
37
|
+
|
|
27
38
|
self.device = device
|
|
28
39
|
self.predictor = None
|
|
29
40
|
|
|
@@ -35,8 +46,6 @@ class DendriteSegmenter:
|
|
|
35
46
|
|
|
36
47
|
# Try importing first to catch import errors
|
|
37
48
|
try:
|
|
38
|
-
import sys
|
|
39
|
-
sys.path.append('./Train-SAMv2')
|
|
40
49
|
from sam2.build_sam import build_sam2
|
|
41
50
|
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
|
42
51
|
print("Successfully imported SAM2 modules")
|
|
@@ -349,11 +349,8 @@ class SegmentationWidget(QWidget):
|
|
|
349
349
|
# Initialize segmenter if not already done
|
|
350
350
|
if self.segmenter is None:
|
|
351
351
|
self.segmenter = DendriteSegmenter(
|
|
352
|
-
model_path="./Train-SAMv2/checkpoints/sam2.1_hiera_small.pt",
|
|
353
|
-
config_path="sam2.1_hiera_s.yaml",
|
|
354
|
-
weights_path="./Train-SAMv2/results/samv2_dendrite/dendrite_model.torch",
|
|
355
352
|
device=device
|
|
356
|
-
)
|
|
353
|
+
) # Paths are now handled automatically by default args
|
|
357
354
|
|
|
358
355
|
# Load the model
|
|
359
356
|
success = self.segmenter.load_model()
|
|
@@ -247,11 +247,13 @@ def main():
|
|
|
247
247
|
else:
|
|
248
248
|
# Try to load a default benchmark image
|
|
249
249
|
try:
|
|
250
|
-
|
|
251
|
-
|
|
250
|
+
from neuro_sam.utils import get_weights_path
|
|
251
|
+
default_path = get_weights_path('DeepD3_Benchmark.tif')
|
|
252
|
+
print(f"No image path provided, loading default: {default_path}")
|
|
252
253
|
spacing_xyz = (args.x_spacing, args.y_spacing, args.z_spacing)
|
|
253
254
|
viewer = run_neuro_sam(image_path=default_path, spacing_xyz=spacing_xyz)
|
|
254
|
-
except
|
|
255
|
+
except Exception as e:
|
|
256
|
+
print(f"Failed to load default image: {e}")
|
|
255
257
|
sys.exit(1)
|
|
256
258
|
|
|
257
259
|
print("\nStarted NeuroSAM with anisotropic scaling support!")
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Init file for training module
|
|
@@ -0,0 +1,226 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
import os
|
|
4
|
+
import wandb
|
|
5
|
+
from torch.onnx.symbolic_opset11 import hstack
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
from .utils.stream_dendrites import DataGeneratorStream
|
|
8
|
+
from sam2.build_sam import build_sam2
|
|
9
|
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
|
10
|
+
import time
|
|
11
|
+
import argparse
|
|
12
|
+
from neuro_sam.utils import get_weights_path
|
|
13
|
+
|
|
14
|
+
def _forward_and_metrics(predictor, image, mask, input_point, input_label, device):
    """Run one prompted SAM2 forward pass and compute loss and overlap metrics.

    Shared by the training step and validation so both measure exactly the
    same quantities.

    Args:
        predictor: ``SAM2ImagePredictor`` wrapping the model being trained.
        image: Image batch accepted by ``predictor.set_image_batch``.
        mask: Ground-truth masks as a NumPy array; summed over axes 1 and 2
            below, so presumably (batch, H, W) — TODO confirm against
            ``DataGeneratorStream``.
        input_point: Prompt point coordinates passed to ``_prep_prompts``.
        input_label: Prompt point labels passed to ``_prep_prompts``.
        device: Torch device the ground-truth masks are moved to.

    Returns:
        Tuple ``(loss, iou, dice)``: ``loss`` is the scalar training loss
        (pixel-wise cross-entropy + 0.05 * IoU-score regression loss),
        ``iou`` is the per-sample IoU tensor and ``dice`` is the batch-mean
        Dice tensor.
    """
    predictor.set_image_batch(image)  # apply SAM image encoder to the image batch

    # Prompt encoding
    _, unnorm_coords, labels, _ = predictor._prep_prompts(
        input_point, input_label, box=None, mask_logits=None, normalize_coords=True
    )
    sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
        points=(unnorm_coords, labels), boxes=None, masks=None,
    )

    # Mask decoder. set_image_batch keeps one high-res feature map per image,
    # so pass them all: slicing a single image (feat_level[-1]) would broadcast
    # the LAST image's features onto every sample of the batch.
    high_res_features = list(predictor._features["high_res_feats"])
    low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
        image_embeddings=predictor._features["image_embed"],
        image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=True,
        repeat_image=False,
        high_res_features=high_res_features,
    )
    # Upscale the masks to the original image resolution
    prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])

    # Segmentation loss: binary cross-entropy on the first mask output
    gt_mask = torch.tensor(mask.astype(np.float32)).to(device)
    prd_mask = torch.sigmoid(prd_masks[:, 0])  # turn logit map into probability map
    seg_loss = (
        -gt_mask * torch.log(prd_mask + 0.00001)
        - (1 - gt_mask) * torch.log((1 - prd_mask) + 0.00001)
    ).mean()

    # Overlap metrics on the thresholded prediction. The denominators are pixel
    # counts, so clamping them to 1 only matters when both masks are empty,
    # where it yields 0 instead of NaN (a NaN would poison the IoU moving average).
    hard_pred = prd_mask > 0.5
    inter = (gt_mask * hard_pred).sum(1).sum(1)
    total = gt_mask.sum(1).sum(1) + hard_pred.sum(1).sum(1)
    iou = inter / (total - inter).clamp(min=1)
    dice = (2 * inter / total.clamp(min=1)).mean()

    # Score loss: regress the decoder's predicted IoU towards the measured IoU
    score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
    loss = seg_loss + score_loss * 0.05  # mix losses
    return loss, iou, dice


def _run_validation(predictor, dg_validation, device, num_batches=20):
    """Estimate validation IoU, Dice and loss on randomly drawn batches.

    Runs without gradients. Batches that fail to load or are empty are skipped.

    Returns:
        ``(mean_iou, mean_dice, mean_loss)`` as floats. IoU is averaged over
        samples, Dice and loss over batches. All three are NaN if no batch
        could be evaluated.
    """
    print('Performing Validation')
    ious, dices, losses = [], [], []
    with torch.no_grad():
        for _ in range(num_batches):
            try:
                n = np.random.randint(len(dg_validation))
                image, mask, input_point, input_label = dg_validation[n]
            except Exception as err:
                print(f'Error in validation batch generation: {err}')
                continue
            if mask.shape[0] == 0:
                continue  # ignore empty batches
            loss, iou, dice = _forward_and_metrics(
                predictor, image, mask, input_point, input_label, device
            )
            # Collect per-sample IoU as plain floats so batches of different
            # sizes can be averaged together.
            ious.extend(iou.float().cpu().tolist())
            dices.append(dice.item())
            losses.append(loss.item())

    if not ious:
        nan = float('nan')
        return nan, nan, nan
    return float(np.mean(ious)), float(np.mean(dices)), float(np.mean(losses))


def main():
    """Command-line entry point: fine-tune SAM2 for dendrite segmentation.

    Parses CLI arguments and builds streaming data generators over
    ``<data_dir>/DeepD3_Training.d3set`` and ``DeepD3_Validation.d3set``.
    It then loads the requested SAM2 checkpoint and trains the mask decoder,
    prompt encoder and image encoder with AdamW (mixed precision on CUDA).
    Every 500 iterations it runs validation and saves a state-dict checkpoint
    under ``results/samv2_<model>_<timestamp>/``. Optionally logs to Weights &
    Biases (``--logger True``).

    Raises:
        FileNotFoundError: if the training or validation data file is missing.
    """
    parser = argparse.ArgumentParser(description="Train Neuro-SAM Dendrite Segmenter")
    parser.add_argument("--ppn", type=int, required=True, help="Positive Points Number")
    parser.add_argument("--pnn", type=int, required=True, help="Negative Points Number")
    parser.add_argument("--model_name", type=str, required=True, choices=['small', 'base_plus', 'large', 'tiny'], help="SAM2 Model Size")
    parser.add_argument("--batch_size", type=int, required=True, help="Batch Size")
    parser.add_argument("--logger", type=str, default="False", help="Use WandB Logger (True/False)")
    parser.add_argument("--data_dir", type=str, default="./data", help="Directory containing .d3set data files")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    TRAINING_DATA_PATH = os.path.join(args.data_dir, "DeepD3_Training.d3set")
    VALIDATION_DATA_PATH = os.path.join(args.data_dir, "DeepD3_Validation.d3set")

    # Fail fast on both files rather than after the model has been loaded.
    for data_path, kind in ((TRAINING_DATA_PATH, "Training"), (VALIDATION_DATA_PATH, "Validation")):
        if not os.path.exists(data_path):
            raise FileNotFoundError(f"{kind} data not found at {data_path}")

    positive_points = args.ppn
    negative_points = args.pnn
    batch_size = args.batch_size
    logger = (args.logger.lower() == "true")

    print("Initializing Data Generator...")
    dg_training = DataGeneratorStream(TRAINING_DATA_PATH,
                                      batch_size=batch_size,  # Data processed at once, depends on your GPU
                                      target_resolution=0.094,  # fixed to 94 nm, can be None for mixed resolution training
                                      min_content=100,  # presumably the minimum segmented-pixel count per sample — confirm in DataGeneratorStream
                                      resizing_size=128,
                                      positive_points=positive_points,
                                      negative_points=negative_points)

    dg_validation = DataGeneratorStream(VALIDATION_DATA_PATH,
                                        batch_size=batch_size,
                                        target_resolution=0.094,
                                        min_content=100,
                                        augment=False,
                                        shuffle=False,
                                        resizing_size=128,
                                        positive_points=positive_points,
                                        negative_points=negative_points)

    # Load model
    model_name_map = {
        'large': ("sam2.1_hiera_large.pt", "sam2.1_hiera_l.yaml"),
        'base_plus': ("sam2.1_hiera_base_plus.pt", "sam2.1_hiera_b+.yaml"),
        'small': ("sam2.1_hiera_small.pt", "sam2.1_hiera_s.yaml"),
        'tiny': ("sam2.1_hiera_tiny.pt", "sam2.1_hiera_t.yaml"),
    }
    ckpt_name, model_cfg = model_name_map[args.model_name]

    # Try to find weights via util or local defaults
    try:
        sam2_checkpoint = get_weights_path(ckpt_name)
    except Exception as err:
        # Fallback if not downloadable or found: treat the name as a local path
        sam2_checkpoint = ckpt_name
        print(f"Warning: Could not resolve weight path for {ckpt_name} ({err}), assuming local file.")

    print(f"Loading SAM2 model: {args.model_name} from {sam2_checkpoint}")
    sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)  # load model
    predictor = SAM2ImagePredictor(sam2_model)

    # Set training parameters
    predictor.model.sam_mask_decoder.train(True)  # enable training of mask decoder
    predictor.model.sam_prompt_encoder.train(True)  # enable training of prompt encoder
    # enable training of image encoder: for this to work any "no_grad" in the
    # SAM2 predictor code path must be removed
    predictor.model.image_encoder.train(True)
    optimizer = torch.optim.AdamW(params=predictor.model.parameters(), lr=1e-5, weight_decay=4e-5)

    # Mixed precision only on CUDA; on CPU, autocast and the scaler are
    # disabled explicitly instead of warning and silently falling back.
    use_amp = device.type == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    time_str = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    ckpt_path = f'results/samv2_{args.model_name}_{time_str}'
    os.makedirs(ckpt_path, exist_ok=True)

    num_iterations = 100000

    if logger:
        wandb.init(
            # set the wandb project where this run will be logged
            project="N-SAMv2",

            # track hyperparameters and run metadata
            config={
                "architecture": "SAMv2",
                "dataset": "DeepD3",
                "model": args.model_name,
                "epochs": num_iterations,
                "ckpt_path": ckpt_path,
                "image_size": (1, 128, 128),
                "min_content": 100,
                "positive_points": positive_points,
                "negative_points": negative_points,
                "batch_size": batch_size,
                "prompt_seed": 42
            }
        )

    # Exponential moving average of training IoU. It is initialised up front so
    # that skipping the very first batch cannot leave it undefined.
    mean_iou = 0.0
    for itr in range(num_iterations):
        n = np.random.randint(len(dg_training))
        try:
            image, mask, input_point, input_label = dg_training[n]
        except Exception as err:
            print(f'Error in training batch: {err}')
            continue
        if mask.shape[0] == 0:
            continue  # ignore empty batches

        # Forward pass and loss under autocast (mixed precision on CUDA)
        with torch.autocast(device_type=device.type, enabled=use_amp):
            loss, iou, dice = _forward_and_metrics(
                predictor, image, mask, input_point, input_label, device
            )

        # Backward and optimizer step outside autocast, as recommended by the
        # PyTorch AMP documentation.
        predictor.model.zero_grad()  # empty gradient
        scaler.scale(loss).backward()  # backpropagate
        scaler.step(optimizer)
        scaler.update()  # mixed precision

        # Display results (log plain floats, not tensors carrying autograd state)
        mean_iou = mean_iou * 0.99 + 0.01 * float(iou.detach().float().mean())
        loss_value = loss.item()
        dice_value = dice.item()
        if logger:
            wandb.log({'step': itr, 'loss': loss_value, 'iou': mean_iou, 'dice_score': dice_value})
        print(f'step: {itr}, loss: {loss_value}, iou: {mean_iou}, dice_score: {dice_value}')

        if itr % 500 == 0:
            val_iou, val_dice, val_loss = _run_validation(predictor, dg_validation, device)
            if logger:
                wandb.log({'step': itr, 'val_loss': val_loss, 'val_iou': val_iou, 'val_dice_score': val_dice})
            print(f'step: {itr}, val_loss: {val_loss}, val_iou: {val_iou}, val_dice_score: {val_dice}')

            torch.save(predictor.model.state_dict(), f"{ckpt_path}/model_{itr}.torch")  # save model

    if logger:
        wandb.finish()


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Init file for training utils
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from scipy import ndimage
|
|
3
|
+
import cv2, random
|
|
4
|
+
from collections import *
|
|
5
|
+
from itertools import *
|
|
6
|
+
from functools import *
|
|
7
|
+
|
|
8
|
+
class PromptGeneration:
    """Sample positive and negative point prompts from a 2-D label mask.

    Positive points are drawn from foreground pixels. Negative points are drawn
    from a background band around the foreground: pixels reached by a square
    dilation of radius ``neg_range[1]`` but not by one of radius
    ``neg_range[0]``. All points are returned as ``[x, y]`` = ``[column, row]``
    pairs.
    """

    def __init__(self, random_seed=0, neg_range=(3, 20), min_content=20):
        """
        Args:
            random_seed: If non-zero, seeds Python's ``random`` module and
                NumPy's global RNG (a process-wide side effect). ``0`` leaves
                both RNGs untouched.
            neg_range: ``(inner, outer)`` dilation radii, in pixels, bounding
                the band negative points are sampled from.
            min_content: Minimum size (in pixels) of a connected component
                for ``get_labelmap`` to keep it.
        """
        if random_seed:
            random.seed(random_seed)
            np.random.seed(random_seed)

        self.neg_range = neg_range

        self.min_area = min_content

    @staticmethod
    def _pixels_as_xy(region):
        """Return the truthy pixels of a 2-D array as ``[x, y]`` (column, row) pairs."""
        return [[col, row] for row, col in np.argwhere(region)]

    def get_labelmap(self, label):
        """Label 8-connected components, dropping those smaller than ``min_area``.

        Args:
            label: 2-D array; non-zero pixels are foreground.

        Returns:
            ``(labelmaps, connected_num)`` as returned by ``scipy.ndimage.label``
            on the filtered mask, so the surviving components are numbered
            ``1..connected_num`` consecutively.
        """
        structure = ndimage.generate_binary_structure(2, 2)  # 3x3 ones: 8-connectivity
        labelmaps, connected_num = ndimage.label(label, structure=structure)

        # All component sizes in a single pass, instead of one full-image scan
        # per component.
        sizes = np.bincount(labelmaps.ravel(), minlength=connected_num + 1)
        keep = sizes >= self.min_area
        keep[0] = False  # label 0 is background, never a component
        filtered = np.where(keep[labelmaps], 255, 0)

        # Relabel so the kept components are numbered consecutively again.
        labelmaps, connected_num = ndimage.label(filtered, structure=structure)
        return labelmaps, connected_num

    def search_negative_region_numpy(self, labelmap):
        """Return the background band between the inner and outer dilation radii.

        ``search(r)`` dilates ``labelmap`` with a ``(2r+1) x (2r+1)`` square
        kernel and removes ``labelmap`` itself. It also removes pixels whose
        neighbourhood contains two different labels: the max-dilation of the
        labels is compared with the max-dilation of ``mx - label``, which
        encodes the min label. The result is ``search(outer) - search(inner)``.
        Callers in this class pass 0/1 maps. With multi-label input, the
        unsigned subtraction may wrap around — NOTE(review): confirm before
        using it with label maps other than 0/1.
        """
        inner_range, outer_range = self.neg_range
        def search(neg_range):
            kernel = np.ones((neg_range * 2 + 1, neg_range * 2 + 1), np.uint8)
            negative_region = cv2.dilate(labelmap, kernel, iterations=1)
            mx = labelmap.max() + 1
            labelmap_r = (mx - labelmap) * np.minimum(1, labelmap)
            r = cv2.dilate(labelmap_r, kernel, iterations=1)
            negative_region_r = (r.astype(np.int32) - mx) * np.minimum(1, r)
            diff = negative_region.astype(np.int32) + negative_region_r
            overlap = np.minimum(1, np.abs(diff).astype(np.uint8))
            return negative_region - overlap - labelmap
        return search(outer_range) - search(inner_range)

    def get_prompt_points(self, label_mask, ppp, ppn):
        """Sample ``ppp`` positive and ``ppn`` negative prompt points.

        Components of ``label_mask`` that survive ``get_labelmap`` filtering
        are visited in random order. Each contributes at most one positive
        point and one negative point (from its own surrounding band). Any
        remaining budget is sampled without replacement from all foreground
        pixels or from the band around the whole foreground.

        Args:
            label_mask: 2-D mask; values >= 1 are treated as foreground.
            ppp: Number of positive points to return.
            ppn: Number of negative points to return.

        Returns:
            ``(coord_positive, coord_negative)``: lists of ``[x, y]`` pairs.

        Raises:
            ValueError: from ``random.sample`` if fewer candidate pixels exist
                than the remaining budget.
        """
        label_mask_cp = np.copy(label_mask)
        label_mask_cp[label_mask_cp >= 1] = 1
        labelmaps, connected_num = self.get_labelmap(label_mask_cp)

        coord_positive, coord_negative = [], []

        connected_components = list(range(1, connected_num + 1))

        # One positive point per component, until the positive budget is used.
        random.shuffle(connected_components)
        for i in connected_components:
            if not ppp:
                break
            coord_positive.append(random.choice(self._pixels_as_xy(labelmaps == i)))
            ppp -= 1

        # One negative point from the band around each component. The
        # per-component dilation is skipped once the negative budget is used.
        random.shuffle(connected_components)
        for i in connected_components:
            if not ppn:
                break
            component = (labelmaps == i).astype(np.uint8)
            negative_region = self.search_negative_region_numpy(component)
            negative_region = negative_region * (1 - label_mask_cp)
            candidates = self._pixels_as_xy(negative_region == 1)
            if not candidates:
                # No background pixels in this component's band (e.g. crowded
                # by other structures or clipped by the image border). Leave
                # the budget to the global sample below instead of crashing
                # in random.choice([]).
                continue
            coord_negative.append(random.choice(candidates))
            ppn -= 1

        # Top up any remaining budget from the whole foreground / whole band.
        if ppp:
            coord_positive += random.sample(self._pixels_as_xy(label_mask_cp == 1), ppp)
        if ppn:
            negative_region = self.search_negative_region_numpy(label_mask_cp)
            coord_negative += random.sample(self._pixels_as_xy(negative_region == 1), ppn)
        return coord_positive, coord_negative
|