neuro_sam-0.1.7-py3-none-any.whl → neuro_sam-0.1.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/neuro_sam/napari_utils/path_tracing_module.py
+++ b/neuro_sam/napari_utils/path_tracing_module.py
@@ -509,10 +509,10 @@ class PathTracingWidget(QWidget):
                 self.scaler.original_spacing_xyz[0] / x_nm  # X
             ])
 
-            self.scaling_status.setText(
-                f"Pending: X={x_nm:.1f}, Y={y_nm:.1f}, Z={z_nm:.1f} nm\n"
-                f"Scale factors (Z,Y,X): {temp_scale_factors[0]:.3f}, {temp_scale_factors[1]:.3f}, {temp_scale_factors[2]:.3f}"
-            )
+            # self.scaling_status.setText(
+            #     f"Pending: X={x_nm:.1f}, Y={y_nm:.1f}, Z={z_nm:.1f} nm\n"
+            #     f"Scale factors (Z,Y,X): {temp_scale_factors[0]:.3f}, {temp_scale_factors[1]:.3f}, {temp_scale_factors[2]:.3f}"
+            # )
 
     def _apply_scaling(self):
         """Apply current scaling settings"""
--- a/neuro_sam/napari_utils/punet_widget.py
+++ b/neuro_sam/napari_utils/punet_widget.py
@@ -21,6 +21,8 @@ from neuro_sam.punet.punet_inference import run_inference_volume
 
 
 
+from neuro_sam.utils import get_weights_path
+
 class PunetSpineSegmentationWidget(QWidget):
     """
     Widget for spine segmentation using Probabilistic U-Net.
@@ -36,7 +38,7 @@ class PunetSpineSegmentationWidget(QWidget):
 
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.model = None
-        self.model_path = "punet/punet_best.pth"  # Default relative path
+        self.model_path = get_weights_path("punet_best.pth")  # Auto-download default weights
 
         # Connect custom progress signal
        self.progress_signal.connect(self._on_worker_progress)
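
This release resolves checkpoints through neuro_sam.utils.get_weights_path rather than hard-coded relative paths. The helper itself is not shown in this diff; the sketch below illustrates what a cache-and-download resolver of this kind typically looks like. The cache directory and base URL are illustrative assumptions, not the package's actual values (requests is already a declared dependency):

    # Hypothetical sketch of a get_weights_path-style helper; the real
    # implementation lives in neuro_sam/utils.py and may differ.
    import os
    import requests

    WEIGHTS_BASE_URL = "https://example.com/neuro-sam/weights"  # assumed, not the real URL
    CACHE_DIR = os.path.join(os.path.expanduser("~"), ".neuro_sam", "weights")

    def get_weights_path_sketch(filename: str) -> str:
        """Return a local path for filename, downloading it on first use."""
        os.makedirs(CACHE_DIR, exist_ok=True)
        local_path = os.path.join(CACHE_DIR, filename)
        if not os.path.exists(local_path):
            resp = requests.get(f"{WEIGHTS_BASE_URL}/{filename}", stream=True, timeout=60)
            resp.raise_for_status()
            with open(local_path, "wb") as f:
                for chunk in resp.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        return local_path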
--- /dev/null
+++ b/neuro_sam/training/__init__.py
@@ -0,0 +1 @@
+# Init file for training module
--- /dev/null
+++ b/neuro_sam/training/train_dendrites.py
@@ -0,0 +1,226 @@
+import numpy as np
+import torch
+import os
+import wandb
+from torch.onnx.symbolic_opset11 import hstack
+import torch.nn.functional as F
+from .utils.stream_dendrites import DataGeneratorStream
+from sam2.build_sam import build_sam2
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+import time
+import argparse
+from neuro_sam.utils import get_weights_path
+
+def main():
+    parser = argparse.ArgumentParser(description="Train Neuro-SAM Dendrite Segmenter")
+    parser.add_argument("--ppn", type=int, required=True, help="Positive Points Number")
+    parser.add_argument("--pnn", type=int, required=True, help="Negative Points Number")
+    parser.add_argument("--model_name", type=str, required=True, choices=['small', 'base_plus', 'large', 'tiny'], help="SAM2 Model Size")
+    parser.add_argument("--batch_size", type=int, required=True, help="Batch Size")
+    parser.add_argument("--logger", type=str, default="False", help="Use WandB Logger (True/False)")
+    parser.add_argument("--data_dir", type=str, default="./data", help="Directory containing .d3set data files")
+    args = parser.parse_args()
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"Using device: {device}")
+
+    TRAINING_DATA_PATH = os.path.join(args.data_dir, "DeepD3_Training.d3set")
+    VALIDATION_DATA_PATH = os.path.join(args.data_dir, "DeepD3_Validation.d3set")
+
+    if not os.path.exists(TRAINING_DATA_PATH):
+        raise FileNotFoundError(f"Training data not found at {TRAINING_DATA_PATH}")
+
+    positive_points = args.ppn
+    negative_points = args.pnn
+    batch_size = args.batch_size
+    logger = (args.logger.lower() == "true")
+
+    print("Initializing Data Generator...")
+    dg_training = DataGeneratorStream(TRAINING_DATA_PATH,
+                                      batch_size=batch_size,  # data processed at once, depends on your GPU
+                                      target_resolution=0.094,  # fixed to 94 nm, can be None for mixed resolution training
+                                      min_content=100,
+                                      resizing_size=128,
+                                      positive_points=positive_points,
+                                      negative_points=negative_points)  # images need at least 100 segmented px
+
+    dg_validation = DataGeneratorStream(VALIDATION_DATA_PATH,
+                                        batch_size=batch_size,
+                                        target_resolution=0.094,
+                                        min_content=100,
+                                        augment=False,
+                                        shuffle=False,
+                                        resizing_size=128,
+                                        positive_points=positive_points,
+                                        negative_points=negative_points)
+
+    # Load model
+    model_name_map = {
+        'large': ("sam2.1_hiera_large.pt", "sam2.1_hiera_l.yaml"),
+        'base_plus': ("sam2.1_hiera_base_plus.pt", "sam2.1_hiera_b+.yaml"),
+        'small': ("sam2.1_hiera_small.pt", "sam2.1_hiera_s.yaml"),
+        'tiny': ("sam2.1_hiera_tiny.pt", "sam2.1_hiera_t.yaml"),
+    }
+
+    ckpt_name, model_cfg = model_name_map[args.model_name]
+
+    # Try to find weights via util or local defaults
+    try:
+        sam2_checkpoint = get_weights_path(ckpt_name)
+    except:
+        # Fallback if not downloadable or found
+        sam2_checkpoint = ckpt_name
+        print(f"Warning: Could not resolve weight path for {ckpt_name}, assuming local file.")
+
+    print(f"Loading SAM2 model: {args.model_name} from {sam2_checkpoint}")
+    sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)  # load model
+
+    predictor = SAM2ImagePredictor(sam2_model)
+
+    # Set training parameters
+
+    predictor.model.sam_mask_decoder.train(True)  # enable training of mask decoder
+    predictor.model.sam_prompt_encoder.train(True)  # enable training of prompt encoder
+    predictor.model.image_encoder.train(True)  # enable training of image encoder: for this to work, scan the code for "no_grad" and remove them all
+    optimizer = torch.optim.AdamW(params=predictor.model.parameters(), lr=1e-5, weight_decay=4e-5)
+    scaler = torch.cuda.amp.GradScaler()  # mixed precision
+
+
+    time_str = "-".join(["{:0>2}".format(x) for x in time.localtime(time.time())][:-3])
+    ckpt_path = f'results/samv2_{args.model_name}_{time_str}'
+    if not os.path.exists(ckpt_path): os.makedirs(ckpt_path)
+
+    if logger:
+        wandb.init(
+            # set the wandb project where this run will be logged
+            project="N-SAMv2",
+
+            # track hyperparameters and run metadata
+            config={
+                "architecture": "SAMv2",
+                "dataset": "DeepD3",
+                "model": args.model_name,
+                "epochs": 100000,
+                "ckpt_path": ckpt_path,
+                "image_size": (1, 128, 128),
+                "min_content": 100,
+                "positive_points": positive_points,
+                "negative_points": negative_points,
+                "batch_size": batch_size,
+                "prompt_seed": 42
+            }
+        )
+
+
+    # add val code here
+
+    def perform_validation(predictor):
+        print('Performing Validation')
+        mean_iou = []
+        mean_dice = []
+        mean_loss = []
+        with torch.no_grad():
+            for i in range(20):
+                try:
+                    n = np.random.randint(len(dg_validation))
+                    image, mask, input_point, input_label = dg_validation[n]
+
+                except:
+                    print('Error in validation batch generation')
+                    continue
+                if mask.shape[0] == 0: continue  # ignore empty batches
+                predictor.set_image_batch(image)  # apply SAM image encoder to the image
+
+                mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(input_point, input_label, box=None, mask_logits=None, normalize_coords=True)
+                sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(points=(unnorm_coords, labels), boxes=None, masks=None,)
+
+                # mask decoder
+
+                high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
+                low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(image_embeddings=predictor._features["image_embed"], image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=True, repeat_image=False, high_res_features=high_res_features,)
+                prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])  # upscale the masks to the original image resolution
+
+                # Segmentation loss calculation
+
+                gt_mask = torch.tensor(mask.astype(np.float32)).to(device)
+                prd_mask = torch.sigmoid(prd_masks[:, 0])  # turn logit map into probability map
+                seg_loss = (-gt_mask * torch.log(prd_mask + 0.00001) - (1 - gt_mask) * torch.log((1 - prd_mask) + 0.00001)).mean()  # cross entropy loss
+
+                # Score loss calculation (intersection over union, IoU)
+
+                inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)
+                iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)
+                total_sum = gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1)
+                dicee = ((2 * inter) / total_sum).mean()
+                score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
+                loss = seg_loss + score_loss * 0.05  # mix losses
+
+                mean_iou.append(iou.cpu().detach().numpy())
+                mean_dice.append(dicee.cpu().detach().numpy())
+                mean_loss.append(loss.cpu().detach().numpy())
+
+        return np.array(mean_iou).mean(), np.array(mean_dice).mean(), np.array(mean_loss).mean()
+
+
+    for itr in range(100000):
+        epoch_dice = epoch_iou = epoch_loss = []
+        with torch.cuda.amp.autocast():  # cast to mixed precision
+            n = np.random.randint(len(dg_training))
+            try:
+                image, mask, input_point, input_label = dg_training[n]
+            except:
+                print('Error in training batch')
+                continue
+            if mask.shape[0] == 0: continue  # ignore empty batches
+            predictor.set_image_batch(image)  # apply SAM image encoder to the image
+
+            # prompt encoding
+
+            mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(input_point, input_label, box=None, mask_logits=None, normalize_coords=True)
+            sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(points=(unnorm_coords, labels), boxes=None, masks=None,)
+
+            # mask decoder
+
+            high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
+            low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(image_embeddings=predictor._features["image_embed"], image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=True, repeat_image=False, high_res_features=high_res_features,)
+            prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])  # upscale the masks to the original image resolution
+
+            gt_mask = torch.tensor(mask.astype(np.float32)).to(device)
+            prd_mask = torch.sigmoid(prd_masks[:, 0])  # turn logit map into probability map
+            seg_loss = (-gt_mask * torch.log(prd_mask + 0.00001) - (1 - gt_mask) * torch.log((1 - prd_mask) + 0.00001)).mean()  # cross entropy loss
+
+            # Score loss calculation (intersection over union, IoU)
+
+            inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)
+            iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)
+            total_sum = gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1)
+            dicee = ((2 * inter) / total_sum).mean()
+            score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
+            loss = seg_loss + score_loss * 0.05  # mix losses
+
+            predictor.model.zero_grad()  # empty gradient
+            scaler.scale(loss).backward()  # backpropagate
+            scaler.step(optimizer)
+            scaler.update()  # mixed precision
+
+
+        # Display results
+        if itr == 0: mean_iou = 0
+        mean_iou = mean_iou * 0.99 + 0.01 * np.mean(iou.cpu().detach().numpy())
+        if logger:
+            wandb.log({'step': itr, 'loss': loss, 'iou': mean_iou, 'dice_score': dicee})
+        print(f'step: {itr}, loss: {loss}, iou: {mean_iou}, dice_score: {dicee}')
+
+        if itr % 500 == 0:
+
+            val_iou, val_dice, val_loss = perform_validation(predictor)
+            if logger:
+                wandb.log({'step': itr, 'val_loss': val_loss, 'val_iou': val_iou, 'val_dice_score': val_dice})
+            print(f'step: {itr}, val_loss: {val_loss}, val_iou: {val_iou}, val_dice_score: {val_dice}')
+
+            torch.save(predictor.model.state_dict(), f"{ckpt_path}/model_{itr}.torch")  # save model
+    if logger:
+        wandb.finish()
+
+if __name__ == "__main__":
+    main()
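
The new training script fine-tunes SAM2 on DeepD3 point prompts, training the mask decoder, prompt encoder, and image encoder jointly. Its loss mixes a per-pixel binary cross-entropy on the predicted mask with a score term that regresses the model's predicted mask quality toward the IoU actually achieved, weighted by 0.05. A self-contained illustration of that combination on dummy tensors (the shapes are assumptions for the example, mirroring the script's 128x128 masks):

    # Illustration of the combined loss used above, on dummy tensors.
    import torch

    gt_mask = (torch.rand(4, 128, 128) > 0.5).float()  # dummy ground-truth masks
    prd_mask = torch.rand(4, 128, 128)                 # dummy predicted probabilities
    prd_scores = torch.rand(4, 1)                      # dummy predicted quality scores

    # Per-pixel binary cross-entropy, with a small eps guarding log(0)
    eps = 1e-5
    seg_loss = (-gt_mask * torch.log(prd_mask + eps)
                - (1 - gt_mask) * torch.log(1 - prd_mask + eps)).mean()

    # Per-sample IoU between the thresholded prediction and the ground truth
    inter = (gt_mask * (prd_mask > 0.5)).sum((1, 2))
    union = gt_mask.sum((1, 2)) + (prd_mask > 0.5).float().sum((1, 2)) - inter
    iou = inter / union

    # The score head learns to predict the achieved IoU
    score_loss = torch.abs(prd_scores[:, 0] - iou).mean()

    loss = seg_loss + 0.05 * score_loss

Note that the script's running mean_iou is an exponential moving average (decay 0.99) of the per-step IoU, and validation plus a checkpoint save run every 500 steps.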
--- /dev/null
+++ b/neuro_sam/training/utils/__init__.py
@@ -0,0 +1 @@
+# Init file for training utils
--- /dev/null
+++ b/neuro_sam/training/utils/prompt_generation_dendrites.py
@@ -0,0 +1,78 @@
+import numpy as np
+from scipy import ndimage
+import cv2, random
+from collections import *
+from itertools import *
+from functools import *
+
+class PromptGeneration:
+    def __init__(self, random_seed=0, neg_range=(3, 20), min_content=20):
+        if random_seed:
+            random.seed(random_seed)
+            np.random.seed(random_seed)
+
+        self.neg_range = neg_range
+
+        self.min_area = min_content
+
+    def get_labelmap(self, label):
+        structure = ndimage.generate_binary_structure(2, 2)
+        labelmaps, connected_num = ndimage.label(label, structure=structure)
+
+        label = np.zeros_like(labelmaps)
+        for i in range(1, 1 + connected_num):
+            if np.sum(labelmaps == i) >= self.min_area: label += np.where(labelmaps == i, 255, 0)
+
+        structure = ndimage.generate_binary_structure(2, 2)
+        labelmaps, connected_num = ndimage.label(label, structure=structure)
+
+        return labelmaps, connected_num
+
+    def search_negative_region_numpy(self, labelmap):
+        inner_range, outer_range = self.neg_range
+        def search(neg_range):
+            kernel = np.ones((neg_range * 2 + 1, neg_range * 2 + 1), np.uint8)
+            negative_region = cv2.dilate(labelmap, kernel, iterations=1)
+            mx = labelmap.max() + 1
+            labelmap_r = (mx - labelmap) * np.minimum(1, labelmap)
+            r = cv2.dilate(labelmap_r, kernel, iterations=1)
+            negative_region_r = (r.astype(np.int32) - mx) * np.minimum(1, r)
+            diff = negative_region.astype(np.int32) + negative_region_r
+            overlap = np.minimum(1, np.abs(diff).astype(np.uint8))
+            return negative_region - overlap - labelmap
+        return search(outer_range) - search(inner_range)
+
+    def get_prompt_points(self, label_mask, ppp, ppn):
+        label_mask_cp = np.copy(label_mask)
+        label_mask_cp[label_mask_cp >= 1] = 1
+        labelmaps, connected_num = self.get_labelmap(label_mask_cp)
+
+        coord_positive, coord_negative = [], []
+
+        connected_components = list(range(1, connected_num + 1))
+        random.shuffle(connected_components)
+
+        for i in connected_components:
+            cc = np.copy(labelmaps)
+            cc[cc != i] = 0
+            cc[cc == i] = 1
+            if ppp:
+                coord_positive.append(random.choice([[y, x] for x, y in np.argwhere(cc == 1)]))
+                ppp -= 1
+
+        random.shuffle(connected_components)
+        for i in connected_components:
+            cc = np.copy(labelmaps)
+            cc[cc != i] = 0
+            cc[cc == i] = 1
+            negative_region = self.search_negative_region_numpy(cc.astype(np.uint8))
+            negative_region = negative_region * (1 - label_mask_cp)
+            if ppn:
+                coord_negative.append(random.choice([[y, x] for x, y in np.argwhere(negative_region == 1)]))
+                ppn -= 1
+
+        negative_region = self.search_negative_region_numpy(label_mask_cp)
+
+        if ppp: coord_positive += random.sample([[y, x] for x, y in np.argwhere(label_mask_cp == 1)], ppp)
+        if ppn: coord_negative += random.sample([[y, x] for x, y in np.argwhere(negative_region == 1)], ppn)
+        return coord_positive, coord_negative
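
PromptGeneration samples positive clicks from each sufficiently large connected component of a binary mask and negative clicks from a dilated ring around the foreground, bounded by the inner and outer radii in neg_range. A toy usage sketch (values are illustrative; scipy and OpenCV are both pulled in by the package):

    # Toy usage of PromptGeneration on a synthetic mask.
    import numpy as np
    from neuro_sam.training.utils.prompt_generation_dendrites import PromptGeneration

    mask = np.zeros((64, 64), dtype=np.uint8)
    mask[20:40, 20:40] = 1  # one square "dendrite" component (400 px > min_content)

    pg = PromptGeneration(random_seed=42, neg_range=(3, 9))
    pos, neg = pg.get_prompt_points(mask, ppp=2, ppn=2)
    # Points come back as [x, y] pairs (note the row/column flip on np.argwhere),
    # ready to be stacked into SAM2 point prompts with labels 1/0.

Drawing negatives from a ring just outside each component steers the model toward the dendrite boundary, where point-prompted segmentation is most ambiguous.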
--- /dev/null
+++ b/neuro_sam/training/utils/stream_dendrites.py
@@ -0,0 +1,299 @@
+import numpy as np
+import flammkuchen as fl
+from tensorflow.keras.utils import Sequence
+from .prompt_generation_dendrites import PromptGeneration
+import cv2
+import albumentations as A
+import random
+
+class DataGeneratorStream(Sequence):
+    def __init__(self, fn, batch_size, samples_per_epoch=50000, size=(1, 128, 128), target_resolution=None, augment=True,
+                 shuffle=True, seed=42, normalize=[-1, 1], resizing_size=1024, min_content=0., positive_points=1, negative_points=1
+                 ):
+        """Data Generator that streams data dynamically for training DeepD3.
+
+        Args:
+            fn (str): The path to the training data file
+            batch_size (int): Batch size for training deep neural networks
+            samples_per_epoch (int, optional): Samples used in each epoch. Defaults to 50000.
+            size (tuple, optional): Shape of a single sample. Defaults to (1, 128, 128).
+            target_resolution (float, optional): Target resolution in microns. Defaults to None.
+            augment (bool, optional): Enables augmenting the data. Defaults to True.
+            shuffle (bool, optional): Enables shuffling the data. Defaults to True.
+            seed (int, optional): Creates pseudorandom numbers for shuffling. Defaults to 42.
+            normalize (list, optional): Value range used when normalizing data. Defaults to [-1, 1].
+            min_content (float, optional): Minimum content in image (annotated dendrite or spine), not considered if 0. Defaults to 0.
+        """
+
+        # Save settings
+        self.batch_size = batch_size
+        self.augment = augment
+        self.fn = fn
+        self.shuffle = shuffle
+        self.aug = self._get_augmenter()
+        self.add_gaus_noise = self._add_gaus_noise()
+        self.seed = seed
+        self.normalize = normalize
+        self.samples_per_epoch = samples_per_epoch
+        self.size = size
+        self.target_resolution = target_resolution
+        self.min_content = min_content
+        self.resizing_size = resizing_size
+        self.ppn = positive_points
+        self.pnn = negative_points
+
+        self.d = fl.load(self.fn)
+        self.data = self.d['data']
+        self.meta = self.d['meta']
+
+        # Seed randomness
+        random.seed(self.seed)
+        np.random.seed(self.seed)
+
+        self.on_epoch_end()
+        self.pg = PromptGeneration(random_seed=42, neg_range=(3, 9))
+
+    def __len__(self):
+        """Denotes the number of batches per epoch"""
+        return self.samples_per_epoch // self.batch_size
+
+    def __getitem__(self, index):
+        """Generate one batch of data
+
+        Parameters
+        ----------
+        index : int
+            batch index in image/label id list
+
+        Returns
+        -------
+        tuple
+            Contains four numpy arrays: images, dendrite masks,
+            prompt point coordinates, and prompt point labels.
+        """
+        X = []
+        Y0 = []
+        Y1 = []
+        pps = []
+        pts = []
+        eps = 1e-5
+
+        if self.shuffle is False:
+            np.random.seed(index)
+
+        # Create all pairs in a given batch
+        for i in range(self.batch_size):
+            # Retrieve a single sample pair
+            image, dendrite, spines = self.getSample()
+            image = cv2.resize(image, (self.resizing_size, self.resizing_size), interpolation=cv2.INTER_LINEAR)
+            dendrite = cv2.resize(dendrite.astype(np.uint8), (self.resizing_size, self.resizing_size), interpolation=cv2.INTER_LINEAR)
+            spines = cv2.resize(spines.astype(np.uint8), (self.resizing_size, self.resizing_size), interpolation=cv2.INTER_LINEAR)
+
+            # Augmenting the data
+            if self.augment:
+                augmented = self.aug(image=image,
+                                     mask1=dendrite,
+                                     mask2=spines)  # augment image
+
+                image = augmented['image']
+                dendrite = augmented['mask1']
+                spines = augmented['mask2']
+                try:
+                    pp, pt = self.extract_points(dendrite)
+
+                    augmented = self.add_gaus_noise(image=image,
+                                                    mask1=dendrite,
+                                                    mask2=spines)  # augment image
+
+                    image = augmented['image']
+                    dendrite = augmented['mask1']
+                    spines = augmented['mask2']
+                except:
+                    pass
+
+            else:
+                dendrite = dendrite
+                spines = spines
+                try:
+                    pp, pt = self.extract_points(dendrite)
+                except:
+                    return False
+
+            # Min/max scaling
+            image = (image.astype(np.float32) - image.min()) / (image.max() - image.min() + eps)
+            # Shifting and scaling
+            # image = image * (self.normalize[1]-self.normalize[0]) + self.normalize[0]
+
+            X.append(cv2.cvtColor(image, cv2.COLOR_GRAY2RGB))
+            Y0.append(dendrite.astype(np.float32) / (dendrite.max() + eps))
+            Y1.append(spines.astype(np.float32) / (spines.max() + eps))  # to ensure binary targets
+            pps.append(pp)
+            pts.append(pt)
+
+        # return np.asarray(X, dtype=np.float32)[..., None], (np.asarray(Y0, dtype=np.float32)[..., None],
+        #                   np.asarray(Y1, dtype=np.float32)[..., None]), (np.asarray(pps), np.asarray(pts))
+        return np.asarray(X, dtype=np.float32), np.asarray(Y0, dtype=np.float32), np.asarray(pps), np.asarray(pts)
+
+    def extract_points(self, dendrite):
+        ppp, ppn = self.ppn, self.pnn
+        coord_positive, coord_negative = self.pg.get_prompt_points(dendrite, ppp, ppn)
+        pp = coord_positive + coord_negative
+        pt = [1] * len(coord_positive) + [0] * len(coord_negative)
+        pp = np.array(pp, dtype=np.float32)
+        pt = np.array(pt, dtype=np.int32)
+        return pp, pt
+
+
+    def _get_augmenter(self):
+        """Defines used augmentations"""
+        aug = A.Compose([
+            A.RandomBrightnessContrast(p=0.25),
+            A.Rotate(limit=10, border_mode=cv2.BORDER_REFLECT, p=0.5),
+            A.RandomRotate90(p=0.5),
+            A.HorizontalFlip(p=0.5),
+            A.VerticalFlip(p=0.5),
+            A.Blur(p=0.2)], p=1,
+            # A.GaussNoise(p=0.5)], p=1,
+            additional_targets={
+                'mask1': 'mask',
+                'mask2': 'mask'
+            })
+        return aug
+
+    def _add_gaus_noise(self):
+        aug = A.Compose([
+            A.GaussNoise(p=0.5)], p=1,
+            additional_targets={
+                'mask1': 'mask',
+                'mask2': 'mask'
+            })
+        return aug
+
+
+    def getSample(self, squeeze=True):
+        """Get a sample from the provided data
+
+        Args:
+            squeeze (bool, optional): if plane is 2D, skip 3D. Defaults to True.
+
+        Returns:
+            list(np.ndarray, np.ndarray, np.ndarray): stack image with respective labels
+        """
+        while True:
+            r = self._getSample(squeeze)
+
+            # If sample was successfully generated
+            # and we don't care about the content
+            if r is not None and self.min_content == 0:
+                return r
+
+            # If sample was successfully generated
+            # and we do care about the content
+            elif r is not None:
+                # Either or both annotations should have at least
+                # `min_content` labelled pixels.
+                if (r[1]).sum() > self.min_content and (r[2]).sum() > self.min_content:
+                    return r
+                else:
+                    continue
+
+            else:
+                continue
+
+    def _getSample(self, squeeze=True):
+        """Retrieves a sample
+
+        Args:
+            squeeze (bool, optional): Squeezes return shape. Defaults to True.
+
+        Returns:
+            tuple: Tuple of stack (X), dendrite (Y0) and spines (Y1)
+        """
+        # Adjust for 2D images
+        if len(self.size) == 2:
+            size = (1,) + self.size
+
+        else:
+            size = self.size
+
+        # Sample a random stack
+        r_stack = np.random.choice(len(self.meta))
+
+        target_h = size[1]
+        target_w = size[2]
+
+
+        if self.target_resolution is None:
+            # Keep everything as is
+            scaling = 1
+            h = target_h
+            w = target_w
+
+        else:
+            # Compute scaling factor
+            scaling = self.target_resolution / self.meta.iloc[r_stack].Resolution_XY
+
+            # Compute the height and width and random offsets
+            h = round(scaling * target_h)
+            w = round(scaling * target_w)
+
+        # Correct for stack dimensions
+        if self.meta.iloc[r_stack].Width - w == 0:
+            x = 0
+
+        elif self.meta.iloc[r_stack].Width - w < 0:
+            return
+
+        else:
+            x = np.random.choice(self.meta.iloc[r_stack].Width - w)
+
+        # Correct for stack dimensions
+        if self.meta.iloc[r_stack].Height - h == 0:
+            y = 0
+
+        elif self.meta.iloc[r_stack].Height - h < 0:
+            return
+
+        else:
+            y = np.random.choice(self.meta.iloc[r_stack].Height - h)
+
+        ## Select random plane + range
+        r_plane = np.random.choice(self.meta.iloc[r_stack].Depth - size[0] + 1)
+
+        z_begin = r_plane
+        z_end = r_plane + size[0]
+
+
+        # Scale if necessary to the correct dimensions
+        tmp_stack = self.data['stacks'][f'x{r_stack}'][z_begin:z_end, y:y+h, x:x+w]
+        tmp_dendrites = self.data['dendrites'][f'x{r_stack}'][z_begin:z_end, y:y+h, x:x+w]
+        tmp_spines = self.data['spines'][f'x{r_stack}'][z_begin:z_end, y:y+h, x:x+w]
+
+        # Data needs to be rescaled
+        if scaling != 1:
+            return_stack = []
+            return_dendrites = []
+            return_spines = []
+
+            # Do this for each plane
+            # and ensure that OpenCV is happy
+            for i in range(tmp_stack.shape[0]):
+                return_stack.append(cv2.resize(tmp_stack[i], (target_h, target_w)))
+                return_dendrites.append(cv2.resize(tmp_dendrites[i].astype(np.uint8), (target_h, target_w)).astype(bool))
+                return_spines.append(cv2.resize(tmp_spines[i].astype(np.uint8), (target_h, target_w)).astype(bool))
+
+            return_stack = np.asarray(return_stack)
+            return_dendrites = np.asarray(return_dendrites)
+            return_spines = np.asarray(return_spines)
+
+        else:
+            return_stack = tmp_stack
+            return_dendrites = tmp_dendrites
+            return_spines = tmp_spines
+
+        if squeeze:
+            # Return sample
+            return return_stack.squeeze(), return_dendrites.squeeze(), return_spines.squeeze()
+
+        else:
+            return return_stack, return_dendrites, return_spines
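
For reference, one batch from DataGeneratorStream pairs RGB-converted image crops with dendrite masks and the generated point prompts. A minimal construction sketch, assuming a DeepD3 .d3set file is available locally (parameter values mirror those used by train_dendrites.py):

    # Sketch: pulling one batch from the streaming generator.
    from neuro_sam.training.utils.stream_dendrites import DataGeneratorStream

    dg = DataGeneratorStream("./data/DeepD3_Training.d3set",
                             batch_size=4,
                             target_resolution=0.094,  # microns/px
                             min_content=100,
                             resizing_size=128,
                             positive_points=1,
                             negative_points=1)

    X, Y0, pps, pts = dg[0]
    # X:   (4, 128, 128, 3) float32 images, min-max scaled to [0, 1]
    # Y0:  (4, 128, 128) float32 dendrite masks
    # pps: (4, 2, 2) point coordinates; pts: (4, 2) labels (1=positive, 0=negative)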
--- a/neuro_sam-0.1.7.dist-info/METADATA
+++ b/neuro_sam-0.1.8.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: neuro-sam
-Version: 0.1.7
+Version: 0.1.8
 Summary: Neuro-SAM: Foundation Models for Dendrite and Dendritic Spine Segmentation
 Author-email: Nipun Arora <nipunarora8@yahoo.com>
 License: MIT License
@@ -55,6 +55,10 @@ Requires-Dist: PyQt5
 Requires-Dist: opencv-python-headless
 Requires-Dist: matplotlib
 Requires-Dist: requests
+Requires-Dist: flammkuchen
+Requires-Dist: albumentations
+Requires-Dist: wandb
+Requires-Dist: tensorflow
 Dynamic: license-file
 
 <div align="center">
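
The four new runtime dependencies all serve the training module introduced in this release: flammkuchen loads the DeepD3 .d3set archives, tensorflow supplies the keras Sequence base class for the data generator, albumentations drives augmentation, and wandb provides optional experiment logging. Note that tensorflow is a heavy dependency to carry for a single base class.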
--- a/neuro_sam-0.1.7.dist-info/RECORD
+++ b/neuro_sam-0.1.8.dist-info/RECORD
@@ -30,12 +30,11 @@ neuro_sam/napari_utils/anisotropic_scaling.py,sha256=VA6Sd9zEhIAhzjAGto2cOjE9moc
 neuro_sam/napari_utils/color_utils.py,sha256=Hf5R8f0rh7b9CY1VT72o3tLGfGnnjRREkX8iWsiiu7k,4243
 neuro_sam/napari_utils/contrasting_color_system.py,sha256=a-lt_3zJLDL9YyIdWJhFDGMYzBb6yH85cV7BNCabbdI,6771
 neuro_sam/napari_utils/main_widget.py,sha256=yahfPLwmhBt_hImpRykIObzfMwbVZvVJTEKKzMZ11bw,48588
-neuro_sam/napari_utils/path_tracing_module.py,sha256=0mMAtrMmtgK_ujMzaWzIguYVDPr8nfzalaTAwgF3NaQ,44062
-neuro_sam/napari_utils/punet_widget.py,sha256=FfnC6V_FErczkaQP5y3rp1YBMWPVx6YMI4TxEHah_Vo,16862
+neuro_sam/napari_utils/path_tracing_module.py,sha256=bDhSawWNMfY-Vs-Zdt8XTb90pLkD9jUBRqpbC2_6li0,44070
+neuro_sam/napari_utils/punet_widget.py,sha256=WCAND8YLn4CA_20YW4quwcuJ_xEqPAc7GHrSBk4Anw0,16928
 neuro_sam/napari_utils/segmentation_model.py,sha256=mHXVjksqEcxHRH5KWp5-hXLEnRHgGhwPUxyUkV8eJGM,34141
 neuro_sam/napari_utils/segmentation_module.py,sha256=iObM5k8VkARtB_rcqAQGzKJ-PmaAKLeFJD14_Jy6xhs,28732
 neuro_sam/napari_utils/visualization_module.py,sha256=JtZlBoKlfIwVLa2Sqg7b2KTr07fNlAcwR0M7fHsn2oM,24723
-neuro_sam/punet/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 neuro_sam/punet/deepd3_model.py,sha256=nGVEqzCPz_E4cFA6QmknW2CffDcjxH7VsdYAyTdAtY0,7509
 neuro_sam/punet/prob_unet_deepd3.py,sha256=syXNleUVrfYtmVveN9G461oAhumxsijsavps8in4VRw,14698
 neuro_sam/punet/prob_unet_with_tversky.py,sha256=2dBbO_BEHbhYWBXW7rXQX6s2DnqoTgBKkgk6VkgN-Ds,12845
@@ -43,7 +42,12 @@ neuro_sam/punet/punet_inference.py,sha256=v5ufB2Zz5WfgfFZ5-rDjBEobpr5gy-HKPPWZpC
 neuro_sam/punet/run_inference.py,sha256=c9ATKWJvhOzNEaww_sUCI5fFS1q0bQ4GYUwNUqxWcwA,5312
 neuro_sam/punet/unet_blocks.py,sha256=ZRNKay9P3OnJ0PmtKXw_iSgUyRE1DkkGefGXwSbYZGY,3171
 neuro_sam/punet/utils.py,sha256=ibwcpkqqZ3_3Afz2VYxzplz8_8FWQ5qYQqjJiKS8hIo,1786
-neuro_sam-0.1.7.dist-info/licenses/LICENSE,sha256=akmTIN8IuZn3Y7UK_8qVQnyKDWSDcVUwB8RPGNXCojw,1068
+neuro_sam/training/__init__.py,sha256=QISf7Tk0xphF08BWHd-pLI1CMHaEzDsTdmUkgzvwAUM,32
+neuro_sam/training/train_dendrites.py,sha256=TMG4YrrQV0Q784omMziXsB71tNICEmGLYSHr0zI9i6Y,11346
+neuro_sam/training/utils/__init__.py,sha256=4hbcx57NtRu8nryvXQYqmXK4hyUgUYNDz97kCw3Efs8,31
+neuro_sam/training/utils/prompt_generation_dendrites.py,sha256=_ntzXNV1lXPrpInRKaZ5CPpq3akF2IuD1naOXbTC8TU,3201
+neuro_sam/training/utils/stream_dendrites.py,sha256=qS_ZWrhJdW1Sg3RBjoRUJFlCV0u5X1Ns_tjYgJUjWJw,11024
+neuro_sam-0.1.8.dist-info/licenses/LICENSE,sha256=akmTIN8IuZn3Y7UK_8qVQnyKDWSDcVUwB8RPGNXCojw,1068
 sam2/__init__.py,sha256=uHyh6VzVS4F2box0rPDpN5UmOVKeQNK0CIaTKG9JQZ4,395
 sam2/automatic_mask_generator.py,sha256=Zt8mbb4UQSMFrjOY8OwbshswOpMhaxAtdn5sTuXUw9c,18461
 sam2/benchmark.py,sha256=m3o1BriIQuwJAx-3zQ_B0_7YLhN84G28oQSV5sGA3ak,2811
@@ -87,8 +91,8 @@ sam2/utils/__init__.py,sha256=NL2AacVHZOe41zp4kF2-ZGcUCi9zFwh1Eo9spNjN0Ko,197
 sam2/utils/amg.py,sha256=t7MwkOKvcuBNu4FcjzKv9BpO0av5Zo9itZ8b3WQMpdg,12842
 sam2/utils/misc.py,sha256=AWAMAcFhzQedcQb7HU2oRc-RqjGrK87K-MsVG21tIKI,13090
 sam2/utils/transforms.py,sha256=ujpk9GAMYvIJIGpt87QOP88TPtrjL61liDG7DCptEUY,4885
-neuro_sam-0.1.7.dist-info/METADATA,sha256=42Pp29GxNSpY1GeNvzeQ3XJyDDsEiMxsCLSSV2Ib-FA,9642
-neuro_sam-0.1.7.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
-neuro_sam-0.1.7.dist-info/entry_points.txt,sha256=a1JXEgiM_QOPJdV8zvcIS60WAE62MeqgIVY2oSx81FY,162
-neuro_sam-0.1.7.dist-info/top_level.txt,sha256=yPbWxFcw79sErTk8zohihUHMK9LL31i3bXir2MrS4OQ,15
-neuro_sam-0.1.7.dist-info/RECORD,,
+neuro_sam-0.1.8.dist-info/METADATA,sha256=1vHp2trqfoQL7eP7aia_Xqa8H9f2b6JjsPECvGpsqrk,9746
+neuro_sam-0.1.8.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+neuro_sam-0.1.8.dist-info/entry_points.txt,sha256=EQg0SmFbnbGGcchHq5ROmhO9pkgby72Y5G5w90WyLZI,220
+neuro_sam-0.1.8.dist-info/top_level.txt,sha256=yPbWxFcw79sErTk8zohihUHMK9LL31i3bXir2MrS4OQ,15
+neuro_sam-0.1.8.dist-info/RECORD,,
--- a/neuro_sam-0.1.7.dist-info/entry_points.txt
+++ b/neuro_sam-0.1.8.dist-info/entry_points.txt
@@ -1,6 +1,7 @@
 [console_scripts]
 neuro-sam = neuro_sam.plugin:main
 neuro-sam-download = neuro_sam.utils:download_all_models
+neuro-sam-train = neuro_sam.training.train_dendrites:main
 
 [napari.manifest]
 neuro-sam = neuro_sam:napari.yaml
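
The added console script exposes the trainer directly on the command line after installation, e.g.:

    neuro-sam-train --ppn 1 --pnn 1 --model_name small --batch_size 4 --data_dir ./data

(argument values above are illustrative; the flags come from the argparse definition in train_dendrites.py, and passing --logger True additionally enables WandB logging).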