neuro-sam 0.1.6__tar.gz → 0.1.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. {neuro_sam-0.1.6/src/neuro_sam.egg-info → neuro_sam-0.1.8}/PKG-INFO +6 -1
  2. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/pyproject.toml +9 -2
  3. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/path_tracing_module.py +4 -4
  4. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/punet_widget.py +6 -4
  5. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/segmentation_model.py +16 -7
  6. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/segmentation_module.py +1 -4
  7. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/plugin.py +5 -3
  8. neuro_sam-0.1.8/src/neuro_sam/training/__init__.py +1 -0
  9. neuro_sam-0.1.8/src/neuro_sam/training/train_dendrites.py +226 -0
  10. neuro_sam-0.1.8/src/neuro_sam/training/utils/__init__.py +1 -0
  11. neuro_sam-0.1.8/src/neuro_sam/training/utils/prompt_generation_dendrites.py +78 -0
  12. neuro_sam-0.1.8/src/neuro_sam/training/utils/stream_dendrites.py +299 -0
  13. neuro_sam-0.1.8/src/neuro_sam/utils.py +90 -0
  14. {neuro_sam-0.1.6 → neuro_sam-0.1.8/src/neuro_sam.egg-info}/PKG-INFO +6 -1
  15. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam.egg-info/SOURCES.txt +6 -1
  16. neuro_sam-0.1.8/src/neuro_sam.egg-info/entry_points.txt +7 -0
  17. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam.egg-info/requires.txt +5 -0
  18. neuro_sam-0.1.6/src/neuro_sam/punet/__init__.py +0 -0
  19. neuro_sam-0.1.6/src/neuro_sam.egg-info/entry_points.txt +0 -5
  20. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/LICENSE +0 -0
  21. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/README.md +0 -0
  22. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/setup.cfg +0 -0
  23. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/__init__.py +0 -0
  24. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/__init__.py +0 -0
  25. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/algorithm/__init__.py +0 -0
  26. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/algorithm/astar.py +0 -0
  27. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/algorithm/waypointastar.py +0 -0
  28. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/algorithm/waypointastar_speedup.py +0 -0
  29. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/connected_componen.py +0 -0
  30. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/cost/__init__.py +0 -0
  31. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/cost/cost.py +0 -0
  32. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/cost/reciprocal.py +0 -0
  33. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/cost/reciprocal_transonic.py +0 -0
  34. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/heuristic/__init__.py +0 -0
  35. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/heuristic/euclidean.py +0 -0
  36. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/heuristic/heuristic.py +0 -0
  37. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/image/__init__.py +0 -0
  38. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/image/stats.py +0 -0
  39. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/input/__init__.py +0 -0
  40. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/input/inputs.py +0 -0
  41. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/node/__init__.py +0 -0
  42. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/node/bidirectional_node.py +0 -0
  43. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/node/node.py +0 -0
  44. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/visualization/__init__.py +0 -0
  45. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/visualization/flythrough.py +0 -0
  46. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/visualization/flythrough_all.py +0 -0
  47. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/visualization/tube_data.py +0 -0
  48. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/brightest_path_lib/visualization/tube_flythrough.py +0 -0
  49. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/anisotropic_scaling.py +0 -0
  50. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/color_utils.py +0 -0
  51. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/contrasting_color_system.py +0 -0
  52. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/main_widget.py +0 -0
  53. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/visualization_module.py +0 -0
  54. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/punet/deepd3_model.py +0 -0
  55. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/punet/prob_unet_deepd3.py +0 -0
  56. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/punet/prob_unet_with_tversky.py +0 -0
  57. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/punet/punet_inference.py +0 -0
  58. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/punet/run_inference.py +0 -0
  59. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/punet/unet_blocks.py +0 -0
  60. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/punet/utils.py +0 -0
  61. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam.egg-info/dependency_links.txt +0 -0
  62. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam.egg-info/top_level.txt +0 -0
  63. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/__init__.py +0 -0
  64. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/automatic_mask_generator.py +0 -0
  65. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/benchmark.py +0 -0
  66. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/build_sam.py +0 -0
  67. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/configs/sam2/sam2_hiera_b+.yaml +0 -0
  68. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/configs/sam2/sam2_hiera_l.yaml +0 -0
  69. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/configs/sam2/sam2_hiera_s.yaml +0 -0
  70. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/configs/sam2/sam2_hiera_t.yaml +0 -0
  71. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +0 -0
  72. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/configs/sam2.1/sam2.1_hiera_l.yaml +0 -0
  73. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/configs/sam2.1/sam2.1_hiera_s.yaml +0 -0
  74. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/configs/sam2.1/sam2.1_hiera_t.yaml +0 -0
  75. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +0 -0
  76. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/configs/train.yaml +0 -0
  77. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/__init__.py +0 -0
  78. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/backbones/__init__.py +0 -0
  79. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/backbones/hieradet.py +0 -0
  80. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/backbones/image_encoder.py +0 -0
  81. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/backbones/utils.py +0 -0
  82. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/memory_attention.py +0 -0
  83. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/memory_encoder.py +0 -0
  84. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/position_encoding.py +0 -0
  85. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/sam/__init__.py +0 -0
  86. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/sam/mask_decoder.py +0 -0
  87. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/sam/prompt_encoder.py +0 -0
  88. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/sam/transformer.py +0 -0
  89. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/sam2_base.py +0 -0
  90. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/modeling/sam2_utils.py +0 -0
  91. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/sam2.1_hiera_b+.yaml +0 -0
  92. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/sam2.1_hiera_l.yaml +0 -0
  93. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/sam2.1_hiera_s.yaml +0 -0
  94. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/sam2.1_hiera_t.yaml +0 -0
  95. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/sam2_hiera_b+.yaml +0 -0
  96. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/sam2_hiera_l.yaml +0 -0
  97. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/sam2_hiera_s.yaml +0 -0
  98. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/sam2_hiera_t.yaml +0 -0
  99. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/sam2_image_predictor.py +0 -0
  100. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/sam2_video_predictor.py +0 -0
  101. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/sam2_video_predictor_legacy.py +0 -0
  102. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/utils/__init__.py +0 -0
  103. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/utils/amg.py +0 -0
  104. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/utils/misc.py +0 -0
  105. {neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/sam2/utils/transforms.py +0 -0
{neuro_sam-0.1.6/src/neuro_sam.egg-info → neuro_sam-0.1.8}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: neuro-sam
-Version: 0.1.6
+Version: 0.1.8
 Summary: Neuro-SAM: Foundation Models for Dendrite and Dendritic Spine Segmentation
 Author-email: Nipun Arora <nipunarora8@yahoo.com>
 License: MIT License
@@ -54,6 +54,11 @@ Requires-Dist: numba
 Requires-Dist: PyQt5
 Requires-Dist: opencv-python-headless
 Requires-Dist: matplotlib
+Requires-Dist: requests
+Requires-Dist: flammkuchen
+Requires-Dist: albumentations
+Requires-Dist: wandb
+Requires-Dist: tensorflow
 Dynamic: license-file
 
 <div align="center">

{neuro_sam-0.1.6 → neuro_sam-0.1.8}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "neuro-sam"
-version = "0.1.6"
+version = "0.1.8"
 description = "Neuro-SAM: Foundation Models for Dendrite and Dendritic Spine Segmentation"
 readme = "README.md"
 authors = [
@@ -37,7 +37,12 @@ dependencies = [
     "numba",
     "PyQt5",
     "opencv-python-headless",
-    "matplotlib"
+    "matplotlib",
+    "requests",
+    "flammkuchen",
+    "albumentations",
+    "wandb",
+    "tensorflow"
 ]
 requires-python = ">=3.10"
 
@@ -55,6 +60,8 @@ include = ["neuro_sam*", "sam2*"]
 
 [project.scripts]
 neuro-sam = "neuro_sam.plugin:main"
+neuro-sam-download = "neuro_sam.utils:download_all_models"
+neuro-sam-train = "neuro_sam.training.train_dendrites:main"
 
 [project.entry-points."napari.manifest"]
 neuro-sam = "neuro_sam:napari.yaml"
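
The two new console scripts point at plain Python callables, so the same functionality is reachable from code. A minimal sketch, assuming neuro-sam 0.1.8 is installed; the import target is copied from the [project.scripts] table above:

    # Pre-fetch all model weights once, up front; equivalent to running the
    # new `neuro-sam-download` console script declared in pyproject.toml.
    from neuro_sam.utils import download_all_models

    download_all_models()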

{neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/path_tracing_module.py
@@ -509,10 +509,10 @@ class PathTracingWidget(QWidget):
             self.scaler.original_spacing_xyz[0] / x_nm  # X
         ])
 
-        self.scaling_status.setText(
-            f"Pending: X={x_nm:.1f}, Y={y_nm:.1f}, Z={z_nm:.1f} nm\n"
-            f"Scale factors (Z,Y,X): {temp_scale_factors[0]:.3f}, {temp_scale_factors[1]:.3f}, {temp_scale_factors[2]:.3f}"
-        )
+        # self.scaling_status.setText(
+        #     f"Pending: X={x_nm:.1f}, Y={y_nm:.1f}, Z={z_nm:.1f} nm\n"
+        #     f"Scale factors (Z,Y,X): {temp_scale_factors[0]:.3f}, {temp_scale_factors[1]:.3f}, {temp_scale_factors[2]:.3f}"
+        # )
 
     def _apply_scaling(self):
         """Apply current scaling settings"""

{neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/punet_widget.py
@@ -21,6 +21,8 @@ from neuro_sam.punet.punet_inference import run_inference_volume
 
 
 
+from neuro_sam.utils import get_weights_path
+
 class PunetSpineSegmentationWidget(QWidget):
     """
     Widget for spine segmentation using Probabilistic U-Net.
@@ -36,7 +38,7 @@ class PunetSpineSegmentationWidget(QWidget):
 
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.model = None
-        self.model_path = "punet/punet_best.pth"  # Default relative path
+        self.model_path = get_weights_path("punet_best.pth")  # Auto-download default weights
 
         # Connect custom progress signal
         self.progress_signal.connect(self._on_worker_progress)
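
Both this widget and the dendrite segmenter below now resolve checkpoints through neuro_sam.utils.get_weights_path. The new utils.py (+90 lines) is not displayed in this diff, so the sketch below is purely illustrative of the cache-then-download pattern the call sites imply; the cache directory and URL table are hypothetical:

    # Illustrative only: the real implementation lives in the new
    # src/neuro_sam/utils.py, whose body this diff does not show.
    import os
    import urllib.request

    CACHE_DIR = os.path.expanduser("~/.neuro_sam/weights")  # hypothetical location
    WEIGHT_URLS = {"punet_best.pth": "https://example.org/punet_best.pth"}  # placeholder

    def get_weights_path(name: str) -> str:
        """Return a local path for `name`, downloading the file on first use."""
        os.makedirs(CACHE_DIR, exist_ok=True)
        path = os.path.join(CACHE_DIR, name)
        if not os.path.exists(path):
            urllib.request.urlretrieve(WEIGHT_URLS[name], path)
        return path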

@@ -191,13 +193,13 @@ class PunetSpineSegmentationWidget(QWidget):
     def _segmentation_worker(self, vol, params):
         import traceback
         try:
+            # Import the refactored inference function from the package
             # Import the refactored inference function from the package
             try:
                 from neuro_sam.punet.punet_inference import run_inference_volume
             except ImportError:
-                import sys
-                sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'punet'))
-                from neuro_sam.punet_inference import run_inference_volume
+                # Fallback should not be needed with proper package execution
+                raise ImportError("Could not import run_inference_volume from neuro_sam.punet.punet_inference")
 
             yield "Starting inference..."
 

{neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/segmentation_model.py
@@ -8,22 +8,33 @@ from scipy.ndimage import label
 from matplotlib.path import Path
 
 
+from neuro_sam.utils import get_weights_path
+
 class DendriteSegmenter:
     """Class for segmenting dendrites from 3D image volumes using SAM2 with overlapping patches"""
 
-    def __init__(self, model_path="./Train-SAMv2/checkpoints/sam2.1_hiera_small.pt", config_path="sam2.1_hiera_s.yaml", weights_path="./Train-SAMv2/results/samv2_dendrite/dendrite_model.torch", device="cuda"):
+    def __init__(self, model_path=None, config_path="sam2.1_hiera_s.yaml", weights_path=None, device="cuda"):
         """
         Initialize the dendrite segmenter with overlapping patches.
 
         Args:
-            model_path: Path to SAM2 model checkpoint
+            model_path: Path to SAM2 model checkpoint (auto-downloaded if None)
             config_path: Path to model configuration
-            weights_path: Path to trained weights
+            weights_path: Path to trained weights (auto-downloaded if None)
             device: Device to run the model on (cpu or cuda)
         """
-        self.model_path = model_path
+        if model_path is None:
+            self.model_path = get_weights_path("sam2.1_hiera_small.pt")
+        else:
+            self.model_path = model_path
+
         self.config_path = config_path
-        self.weights_path = weights_path
+
+        if weights_path is None:
+            self.weights_path = get_weights_path("dendrite_model.torch")
+        else:
+            self.weights_path = weights_path
+
         self.device = device
         self.predictor = None
 
@@ -35,8 +46,6 @@ class DendriteSegmenter:
 
         # Try importing first to catch import errors
         try:
-            import sys
-            sys.path.append('./Train-SAMv2')
             from sam2.build_sam import build_sam2
             from sam2.sam2_image_predictor import SAM2ImagePredictor
             print("Successfully imported SAM2 modules")
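
With the new defaults, callers no longer pass checkpoint paths at all. A minimal usage sketch; the class and its load_model() method appear in this file, while the device choice is illustrative:

    # model_path and weights_path default to None and are resolved to
    # sam2.1_hiera_small.pt and dendrite_model.torch via get_weights_path.
    from neuro_sam.napari_utils.segmentation_model import DendriteSegmenter

    segmenter = DendriteSegmenter(device="cpu")  # "cuda" if a GPU is available
    segmenter.load_model()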

{neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/napari_utils/segmentation_module.py
@@ -349,11 +349,8 @@ class SegmentationWidget(QWidget):
         # Initialize segmenter if not already done
         if self.segmenter is None:
             self.segmenter = DendriteSegmenter(
-                model_path="./Train-SAMv2/checkpoints/sam2.1_hiera_small.pt",
-                config_path="sam2.1_hiera_s.yaml",
-                weights_path="./Train-SAMv2/results/samv2_dendrite/dendrite_model.torch",
                 device=device
-            )
+            )  # Paths are now handled automatically by default args
 
         # Load the model
         success = self.segmenter.load_model()

{neuro_sam-0.1.6 → neuro_sam-0.1.8}/src/neuro_sam/plugin.py
@@ -247,11 +247,13 @@ def main():
     else:
         # Try to load a default benchmark image
         try:
-            default_path = './DeepD3_Benchmark.tif'
-            print(f"No image path provided, trying to load default: {default_path}")
+            from neuro_sam.utils import get_weights_path
+            default_path = get_weights_path('DeepD3_Benchmark.tif')
+            print(f"No image path provided, loading default: {default_path}")
             spacing_xyz = (args.x_spacing, args.y_spacing, args.z_spacing)
             viewer = run_neuro_sam(image_path=default_path, spacing_xyz=spacing_xyz)
-        except FileNotFoundError:
+        except Exception as e:
+            print(f"Failed to load default image: {e}")
             sys.exit(1)
 
     print("\nStarted NeuroSAM with anisotropic scaling support!")
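
run_neuro_sam can also be called directly from Python. A hedged sketch; the function name and keyword arguments are taken from the plugin code above, but the file name and spacing values are illustrative:

    from neuro_sam.plugin import run_neuro_sam

    # Launch the napari viewer on your own stack (spacing values are made up).
    viewer = run_neuro_sam(image_path="my_stack.tif", spacing_xyz=(94.0, 94.0, 500.0))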

neuro_sam-0.1.8/src/neuro_sam/training/__init__.py
@@ -0,0 +1 @@
+# Init file for training module

neuro_sam-0.1.8/src/neuro_sam/training/train_dendrites.py
@@ -0,0 +1,226 @@
+import numpy as np
+import torch
+import os
+import wandb
+from torch.onnx.symbolic_opset11 import hstack
+import torch.nn.functional as F
+from .utils.stream_dendrites import DataGeneratorStream
+from sam2.build_sam import build_sam2
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+import time
+import argparse
+from neuro_sam.utils import get_weights_path
+
+def main():
+    parser = argparse.ArgumentParser(description="Train Neuro-SAM Dendrite Segmenter")
+    parser.add_argument("--ppn", type=int, required=True, help="Positive Points Number")
+    parser.add_argument("--pnn", type=int, required=True, help="Negative Points Number")
+    parser.add_argument("--model_name", type=str, required=True, choices=['small', 'base_plus', 'large', 'tiny'], help="SAM2 Model Size")
+    parser.add_argument("--batch_size", type=int, required=True, help="Batch Size")
+    parser.add_argument("--logger", type=str, default="False", help="Use WandB Logger (True/False)")
+    parser.add_argument("--data_dir", type=str, default="./data", help="Directory containing .d3set data files")
+    args = parser.parse_args()
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"Using device: {device}")
+
+    TRAINING_DATA_PATH = os.path.join(args.data_dir, "DeepD3_Training.d3set")
+    VALIDATION_DATA_PATH = os.path.join(args.data_dir, "DeepD3_Validation.d3set")
+
+    if not os.path.exists(TRAINING_DATA_PATH):
+        raise FileNotFoundError(f"Training data not found at {TRAINING_DATA_PATH}")
+
+    positive_points = args.ppn
+    negative_points = args.pnn
+    batch_size = args.batch_size
+    logger = (args.logger.lower() == "true")
+
+    print("Initializing Data Generator...")
+    dg_training = DataGeneratorStream(TRAINING_DATA_PATH,
+                                      batch_size=batch_size,  # Data processed at once, depends on your GPU
+                                      target_resolution=0.094,  # fixed to 94 nm, can be None for mixed resolution training
+                                      min_content=100,
+                                      resizing_size=128,
+                                      positive_points=positive_points,
+                                      negative_points=negative_points)  # images need to have at least 100 segmented px
+
+    dg_validation = DataGeneratorStream(VALIDATION_DATA_PATH,
+                                        batch_size=batch_size,
+                                        target_resolution=0.094,
+                                        min_content=100,
+                                        augment=False,
+                                        shuffle=False,
+                                        resizing_size=128,
+                                        positive_points=positive_points,
+                                        negative_points=negative_points)
+
+    # Load model
+    model_name_map = {
+        'large': ("sam2.1_hiera_large.pt", "sam2.1_hiera_l.yaml"),
+        'base_plus': ("sam2.1_hiera_base_plus.pt", "sam2.1_hiera_b+.yaml"),
+        'small': ("sam2.1_hiera_small.pt", "sam2.1_hiera_s.yaml"),
+        'tiny': ("sam2.1_hiera_tiny.pt", "sam2.1_hiera_t.yaml"),
+    }
+
+    ckpt_name, model_cfg = model_name_map[args.model_name]
+
+    # Try to find weights via util or local defaults
+    try:
+        sam2_checkpoint = get_weights_path(ckpt_name)
+    except:
+        # Fallback if not downloadable or found
+        sam2_checkpoint = ckpt_name
+        print(f"Warning: Could not resolve weight path for {ckpt_name}, assuming local file.")
+
+    print(f"Loading SAM2 model: {args.model_name} from {sam2_checkpoint}")
+    sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)  # load model
+
+    predictor = SAM2ImagePredictor(sam2_model)
+
+    # Set training parameters
+
+    predictor.model.sam_mask_decoder.train(True)  # enable training of mask decoder
+    predictor.model.sam_prompt_encoder.train(True)  # enable training of prompt encoder
+    predictor.model.image_encoder.train(True)  # enable training of image encoder: for this to work you need to scan the code for "no_grad" and remove them all
+    optimizer = torch.optim.AdamW(params=predictor.model.parameters(), lr=1e-5, weight_decay=4e-5)
+    scaler = torch.cuda.amp.GradScaler()  # mixed precision
+
+
+    time_str = "-".join(["{:0>2}".format(x) for x in time.localtime(time.time())][:-3])
+    ckpt_path = f'results/samv2_{args.model_name}_{time_str}'
+    if not os.path.exists(ckpt_path): os.makedirs(ckpt_path)
+
+    if logger:
+        wandb.init(
+            # set the wandb project where this run will be logged
+            project="N-SAMv2",
+
+            # track hyperparameters and run metadata
+            config={
+                "architecture": "SAMv2",
+                "dataset": "DeepD3",
+                "model": args.model_name,
+                "epochs": 100000,
+                "ckpt_path": ckpt_path,
+                "image_size": (1, 128, 128),
+                "min_content": 100,
+                "positive_points": positive_points,
+                "negative_points": negative_points,
+                "batch_size": batch_size,
+                "prompt_seed": 42
+            }
+        )
+
+
+    # add val code here
+
+    def perform_validation(predictor):
+        print('Performing Validation')
+        mean_iou = []
+        mean_dice = []
+        mean_loss = []
+        with torch.no_grad():
+            for i in range(20):
+                try:
+                    n = np.random.randint(len(dg_validation))
+                    image, mask, input_point, input_label = dg_validation[n]
+
+                except:
+                    print('Error in validation batch generation')
+                    continue
+                if mask.shape[0] == 0: continue  # ignore empty batches
+                predictor.set_image_batch(image)  # apply SAM image encoder to the image
+
+                mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(input_point, input_label, box=None, mask_logits=None, normalize_coords=True)
+                sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(points=(unnorm_coords, labels), boxes=None, masks=None,)
+
+                # mask decoder
+
+                high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
+                low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(image_embeddings=predictor._features["image_embed"], image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=True, repeat_image=False, high_res_features=high_res_features,)
+                prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])  # Upscale the masks to the original image resolution
+
+                # Segmentation loss calculation
+
+                gt_mask = torch.tensor(mask.astype(np.float32)).to(device)
+                prd_mask = torch.sigmoid(prd_masks[:, 0])  # Turn logit map to probability map
+                seg_loss = (-gt_mask * torch.log(prd_mask + 0.00001) - (1 - gt_mask) * torch.log((1 - prd_mask) + 0.00001)).mean()  # cross entropy loss
+
+                # Score loss calculation (intersection over union) IOU
+
+                inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)
+                iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)
+                total_sum = gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1)
+                dicee = ((2 * inter) / total_sum).mean()
+                score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
+                loss = seg_loss + score_loss * 0.05  # mix losses
+
+                mean_iou.append(iou.cpu().detach().numpy())
+                mean_dice.append(dicee.cpu().detach().numpy())
+                mean_loss.append(loss.cpu().detach().numpy())
+
+        return np.array(mean_iou).mean(), np.array(mean_dice).mean(), np.array(mean_loss).mean()
+
+
+    for itr in range(100000):
+        epoch_dice = epoch_iou = epoch_loss = []
+        with torch.cuda.amp.autocast():  # cast to mixed precision
+            n = np.random.randint(len(dg_training))
+            try:
+                image, mask, input_point, input_label = dg_training[n]
+            except:
+                print('Error in training batch')
+                continue
+            if mask.shape[0] == 0: continue  # ignore empty batches
+            predictor.set_image_batch(image)  # apply SAM image encoder to the image
+
+            # prompt encoding
+
+            mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(input_point, input_label, box=None, mask_logits=None, normalize_coords=True)
+            sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(points=(unnorm_coords, labels), boxes=None, masks=None,)
+
+            # mask decoder
+
+            high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
+            low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(image_embeddings=predictor._features["image_embed"], image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=True, repeat_image=False, high_res_features=high_res_features,)
+            prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])  # Upscale the masks to the original image resolution
+
+            gt_mask = torch.tensor(mask.astype(np.float32)).to(device)
+            prd_mask = torch.sigmoid(prd_masks[:, 0])  # Turn logit map to probability map
+            seg_loss = (-gt_mask * torch.log(prd_mask + 0.00001) - (1 - gt_mask) * torch.log((1 - prd_mask) + 0.00001)).mean()  # cross entropy loss
+
+            # Score loss calculation (intersection over union) IOU
+
+            inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)
+            iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)
+            total_sum = gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1)
+            dicee = ((2 * inter) / total_sum).mean()
+            score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
+            loss = seg_loss + score_loss * 0.05  # mix losses
+
+            predictor.model.zero_grad()  # empty gradient
+            scaler.scale(loss).backward()  # Backpropagate
+            scaler.step(optimizer)
+            scaler.update()  # Mixed precision
+
+
+        # Display results
+        if itr == 0: mean_iou = 0
+        mean_iou = mean_iou * 0.99 + 0.01 * np.mean(iou.cpu().detach().numpy())
+        if logger:
+            wandb.log({'step': itr, 'loss': loss, 'iou': mean_iou, 'dice_score': dicee})
+        print(f'step: {itr}, loss: {loss}, iou: {mean_iou}, dice_score: {dicee}')
+
+        if itr % 500 == 0:
+
+            val_iou, val_dice, val_loss = perform_validation(predictor)
+            if logger:
+                wandb.log({'step': itr, 'val_loss': val_loss, 'val_iou': val_iou, 'val_dice_score': val_dice})
+            print(f'step: {itr}, val_loss: {val_loss}, val_iou: {val_iou}, val_dice_score: {val_dice}')
+
+            torch.save(predictor.model.state_dict(), f"{ckpt_path}/model_{itr}.torch")  # save model
+    if logger:
+        wandb.finish()
+
+if __name__ == "__main__":
+    main()
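
The new neuro-sam-train entry point wraps this main(). A hedged sketch of an equivalent programmatic invocation; the flag names come from the argparse definitions above, while the specific values and data location are illustrative (the .d3set files must exist under --data_dir or main() raises FileNotFoundError):

    # Same effect as: neuro-sam-train --ppn 2 --pnn 2 --model_name small
    #                 --batch_size 4 --logger False --data_dir ./data
    import sys
    from neuro_sam.training.train_dendrites import main

    sys.argv = ["neuro-sam-train",
                "--ppn", "2", "--pnn", "2",  # prompt point counts (illustrative)
                "--model_name", "small",
                "--batch_size", "4",
                "--logger", "False",         # "True" logs to Weights & Biases
                "--data_dir", "./data"]
    main()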

neuro_sam-0.1.8/src/neuro_sam/training/utils/__init__.py
@@ -0,0 +1 @@
+# Init file for training utils

neuro_sam-0.1.8/src/neuro_sam/training/utils/prompt_generation_dendrites.py
@@ -0,0 +1,78 @@
+import numpy as np
+from scipy import ndimage
+import cv2, random
+from collections import *
+from itertools import *
+from functools import *
+
+class PromptGeneration:
+    def __init__(self, random_seed=0, neg_range=(3, 20), min_content=20):
+        if random_seed:
+            random.seed(random_seed)
+            np.random.seed(random_seed)
+
+        self.neg_range = neg_range
+
+        self.min_area = min_content
+
+    def get_labelmap(self, label):
+        structure = ndimage.generate_binary_structure(2, 2)
+        labelmaps, connected_num = ndimage.label(label, structure=structure)
+
+        label = np.zeros_like(labelmaps)
+        for i in range(1, 1 + connected_num):
+            if np.sum(labelmaps == i) >= self.min_area: label += np.where(labelmaps == i, 255, 0)
+
+        structure = ndimage.generate_binary_structure(2, 2)
+        labelmaps, connected_num = ndimage.label(label, structure=structure)
+
+        return labelmaps, connected_num
+
+    def search_negative_region_numpy(self, labelmap):
+        inner_range, outer_range = self.neg_range
+        def search(neg_range):
+            kernel = np.ones((neg_range * 2 + 1, neg_range * 2 + 1), np.uint8)
+            negative_region = cv2.dilate(labelmap, kernel, iterations=1)
+            mx = labelmap.max() + 1
+            labelmap_r = (mx - labelmap) * np.minimum(1, labelmap)
+            r = cv2.dilate(labelmap_r, kernel, iterations=1)
+            negative_region_r = (r.astype(np.int32) - mx) * np.minimum(1, r)
+            diff = negative_region.astype(np.int32) + negative_region_r
+            overlap = np.minimum(1, np.abs(diff).astype(np.uint8))
+            return negative_region - overlap - labelmap
+        return search(outer_range) - search(inner_range)
+
+    def get_prompt_points(self, label_mask, ppp, ppn):
+        label_mask_cp = np.copy(label_mask)
+        label_mask_cp[label_mask_cp >= 1] = 1
+        labelmaps, connected_num = self.get_labelmap(label_mask_cp)
+
+        coord_positive, coord_negative = [], []
+
+        connected_components = list(range(1, connected_num + 1))
+        random.shuffle(connected_components)
+
+        for i in connected_components:
+            cc = np.copy(labelmaps)
+            cc[cc != i] = 0
+            cc[cc == i] = 1
+            if ppp:
+                coord_positive.append(random.choice([[y, x] for x, y in np.argwhere(cc == 1)]))
+                ppp -= 1
+
+        random.shuffle(connected_components)
+        for i in connected_components:
+            cc = np.copy(labelmaps)
+            cc[cc != i] = 0
+            cc[cc == i] = 1
+            negative_region = self.search_negative_region_numpy(cc.astype(np.uint8))
+            negative_region = negative_region * (1 - label_mask_cp)
+            if ppn:
+                coord_negative.append(random.choice([[y, x] for x, y in np.argwhere(negative_region == 1)]))
+                ppn -= 1
+
+        negative_region = self.search_negative_region_numpy(label_mask_cp)
+
+        if ppp: coord_positive += random.sample([[y, x] for x, y in np.argwhere(label_mask_cp == 1)], ppp)
+        if ppn: coord_negative += random.sample([[y, x] for x, y in np.argwhere(negative_region == 1)], ppn)
+        return coord_positive, coord_negative
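
A short usage sketch for the prompt generator; the toy mask and point counts are illustrative. Note that get_prompt_points returns coordinates as [column, row] pairs: np.argwhere yields (row, col), which the list comprehensions swap into the (x, y) order SAM-style prompts expect:

    import numpy as np
    from neuro_sam.training.utils.prompt_generation_dendrites import PromptGeneration

    mask = np.zeros((128, 128), dtype=np.uint8)
    mask[40:60, 50:90] = 1  # hypothetical dendrite blob

    pg = PromptGeneration(random_seed=42)
    positives, negatives = pg.get_prompt_points(mask, ppp=2, ppn=2)
    print(positives, negatives)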