nettracer3d 0.4.4__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nettracer3d/segmenter.py CHANGED
@@ -1,10 +1,16 @@
1
1
  from sklearn.ensemble import RandomForestClassifier
2
2
  import numpy as np
3
- import cupy as cp
4
- import torch
3
+ try:
4
+ import torch
5
+ import cupy as cp
6
+ import cupyx.scipy.ndimage as cpx
7
+ from cuml.ensemble import RandomForestClassifier as cuRandomForestClassifier
8
+ except:
9
+ pass
10
+ import concurrent.futures
5
11
  from concurrent.futures import ThreadPoolExecutor
6
12
  import threading
7
- import cupyx.scipy.ndimage as cpx
13
+ from scipy import ndimage
8
14
 
9
15
 
10
16
  class InteractiveSegmenter:
@@ -16,19 +22,100 @@ class InteractiveSegmenter:
16
22
  if self.use_gpu:
17
23
  print(f"Using GPU: {torch.cuda.get_device_name()}")
18
24
  self.image_gpu = cp.asarray(image_3d)
19
-
20
- self.model = RandomForestClassifier(
21
- n_estimators=100,
22
- n_jobs=-1,
23
- max_depth=None
24
- )
25
+ try:
26
+ self.model = cuRandomForestClassifier(
27
+ n_estimators=100,
28
+ max_depth=None
29
+ )
30
+ except:
31
+ self.model = RandomForestClassifier(
32
+ n_estimators=100,
33
+ n_jobs=-1,
34
+ max_depth=None
35
+ )
36
+
37
+ else:
38
+
39
+ self.model = RandomForestClassifier(
40
+ n_estimators=100,
41
+ n_jobs=-1,
42
+ max_depth=None
43
+ )
44
+
25
45
  self.feature_cache = None
26
46
  self.lock = threading.Lock()
27
47
 
48
+ # Current position attributes
49
+ self.current_z = None
50
+ self.current_x = None
51
+ self.current_y = None
52
+
53
+ self.realtimechunks = None
54
+ self.current_speed = False
55
+
56
+ def compute_deep_feature_maps_cpu(self):
57
+ """Compute feature maps using CPU"""
58
+ features = []
59
+ original_shape = self.image_3d.shape
60
+
61
+ # Gaussian and DoG using scipy
62
+ print("Obtaining gaussians")
63
+ for sigma in [0.5, 1.0, 2.0, 4.0]:
64
+ smooth = ndimage.gaussian_filter(self.image_3d, sigma)
65
+ features.append(smooth)
66
+
67
+ print("Computing local statistics")
68
+ # Local statistics using scipy's convolve
69
+ window_size = 5
70
+ kernel = np.ones((window_size, window_size, window_size)) / (window_size**3)
71
+
72
+ # Local mean
73
+ local_mean = ndimage.convolve(self.image_3d, kernel, mode='reflect')
74
+ features.append(local_mean)
75
+
76
+ # Local variance
77
+ mean = np.mean(self.image_3d)
78
+ local_var = ndimage.convolve((self.image_3d - mean)**2, kernel, mode='reflect')
79
+ features.append(local_var)
80
+
81
+ print("Computing sobel and gradients")
82
+ # Gradient computations using scipy
83
+ gx = ndimage.sobel(self.image_3d, axis=2, mode='reflect')
84
+ gy = ndimage.sobel(self.image_3d, axis=1, mode='reflect')
85
+ gz = ndimage.sobel(self.image_3d, axis=0, mode='reflect')
86
+
87
+ # Gradient magnitude
88
+ gradient_magnitude = np.sqrt(gx**2 + gy**2 + gz**2)
89
+ features.append(gradient_magnitude)
90
+
91
+ print("Computing second-order features")
92
+ # Second-order gradients
93
+ gxx = ndimage.sobel(gx, axis=2, mode='reflect')
94
+ gyy = ndimage.sobel(gy, axis=1, mode='reflect')
95
+ gzz = ndimage.sobel(gz, axis=0, mode='reflect')
96
+
97
+ # Laplacian (sum of second derivatives)
98
+ laplacian = gxx + gyy + gzz
99
+ features.append(laplacian)
100
+
101
+ # Hessian determinant
102
+ hessian_det = gxx * gyy * gzz
103
+ features.append(hessian_det)
104
+
105
+ print("Verifying shapes")
106
+ for i, feat in enumerate(features):
107
+ if feat.shape != original_shape:
108
+ feat_adjusted = np.expand_dims(feat, axis=0)
109
+ if feat_adjusted.shape != original_shape:
110
+ raise ValueError(f"Feature {i} has shape {feat.shape}, expected {original_shape}")
111
+ features[i] = feat_adjusted
112
+
113
+ return np.stack(features, axis=-1)
114
+
28
115
  def compute_feature_maps(self):
29
116
  """Compute all feature maps using GPU acceleration"""
30
- if not self.use_gpu:
31
- return super().compute_feature_maps()
117
+ #if not self.use_gpu:
118
+ #return super().compute_feature_maps()
32
119
 
33
120
  features = []
34
121
  image = self.image_gpu
@@ -85,7 +172,170 @@ class InteractiveSegmenter:
85
172
  # Verify shapes
86
173
  for i, feat in enumerate(features):
87
174
  if feat.shape != original_shape:
88
- raise ValueError(f"Feature {i} has shape {feat.shape}, expected {original_shape}")
175
+ # Create a copy of the feature to modify
176
+ feat_adjusted = np.expand_dims(feat, axis=0)
177
+ if feat_adjusted.shape != original_shape:
178
+ raise ValueError(f"Feature {i} has shape {feat.shape}, expected {original_shape}")
179
+ # Important: Update the original features list with the expanded version
180
+ features[i] = feat_adjusted
181
+
182
+ return np.stack(features, axis=-1)
183
+
184
+ def compute_feature_maps_cpu(self):
185
+ """Compute feature maps using CPU"""
186
+ features = []
187
+ original_shape = self.image_3d.shape
188
+
189
+ # Gaussian smoothing at different scales
190
+ print("Obtaining gaussians")
191
+ for sigma in [0.5, 1.0, 2.0, 4.0]:
192
+ smooth = ndimage.gaussian_filter(self.image_3d, sigma)
193
+ features.append(smooth)
194
+
195
+ print("Obtaining dif of gaussians")
196
+ # Difference of Gaussians
197
+ for (s1, s2) in [(1, 2), (2, 4)]:
198
+ g1 = ndimage.gaussian_filter(self.image_3d, s1)
199
+ g2 = ndimage.gaussian_filter(self.image_3d, s2)
200
+ dog = g1 - g2
201
+ features.append(dog)
202
+
203
+ print("Computing sobel and gradients")
204
+ # Gradient computations using scipy
205
+ gx = ndimage.sobel(self.image_3d, axis=2, mode='reflect') # x direction
206
+ gy = ndimage.sobel(self.image_3d, axis=1, mode='reflect') # y direction
207
+ gz = ndimage.sobel(self.image_3d, axis=0, mode='reflect') # z direction
208
+
209
+ # Gradient magnitude
210
+ print("Computing gradient magnitude")
211
+ gradient_magnitude = np.sqrt(gx**2 + gy**2 + gz**2)
212
+ features.append(gradient_magnitude)
213
+
214
+ # Verify shapes
215
+ print("Verifying shapes")
216
+ for i, feat in enumerate(features):
217
+ if feat.shape != original_shape:
218
+ feat_adjusted = np.expand_dims(feat, axis=0)
219
+ if feat_adjusted.shape != original_shape:
220
+ raise ValueError(f"Feature {i} has shape {feat.shape}, expected {original_shape}")
221
+ features[i] = feat_adjusted
222
+
223
+ return np.stack(features, axis=-1)
224
+
225
    def compute_deep_feature_maps(self):
        """Compute the extended ("deep") feature stack with GPU acceleration.

        Uses CuPy for the Gaussian pyramid and PyTorch CUDA convolutions for
        local statistics, Sobel gradients, and second-order responses.
        Requires a working CUDA device (`.cuda()` / device='cuda' calls).
        Returns a NumPy array of shape image_3d.shape + (n_features,).
        """
        #if not self.use_gpu:
            #return super().compute_feature_maps()

        features = []
        image = self.image_gpu
        original_shape = self.image_3d.shape

        # Original features (Gaussians and DoG)
        print("Obtaining gaussians")
        for sigma in [0.5, 1.0, 2.0, 4.0]:
            # gaussian_filter_gpu is defined elsewhere in this class;
            # results come back to host memory immediately.
            smooth = cp.asnumpy(self.gaussian_filter_gpu(image, sigma))
            features.append(smooth)

        print("Computing local statistics")
        image_torch = torch.from_numpy(self.image_3d).cuda()
        # Add batch and channel axes for torch conv ops.
        image_torch = image_torch.float().unsqueeze(0).unsqueeze(1) # [1, 1, 1, 512, 384]

        # Create kernel
        window_size = 5
        pad = window_size // 2

        if image_torch.shape[2] == 1: # Single slice case
            # Squeeze out the z dimension for 2D operations
            image_2d = image_torch.squeeze(2) # Now [1, 1, 512, 384]
            # Normalized box kernel -> conv2d computes a local mean.
            kernel_2d = torch.ones((1, 1, window_size, window_size), device='cuda')
            kernel_2d = kernel_2d / (window_size**2)

            # 2D padding and convolution
            padded = torch.nn.functional.pad(image_2d,
                                             (pad, pad,   # x dimension
                                              pad, pad),  # y dimension
                                             mode='reflect')

            local_mean = torch.nn.functional.conv2d(padded, kernel_2d)
            local_mean = local_mean.unsqueeze(2) # Add z dimension back
            features.append(local_mean.cpu().numpy().squeeze())

            # Local variance: box-filtered squared deviation from the
            # global mean of the slice.
            mean = torch.mean(image_2d)
            padded_sq = torch.nn.functional.pad((image_2d - mean)**2,
                                                (pad, pad, pad, pad),
                                                mode='reflect')
            local_var = torch.nn.functional.conv2d(padded_sq, kernel_2d)
            local_var = local_var.unsqueeze(2) # Add z dimension back
            features.append(local_var.cpu().numpy().squeeze())
        else:
            # Original 3D operations for multi-slice case
            kernel = torch.ones((1, 1, window_size, window_size, window_size), device='cuda')
            kernel = kernel / (window_size**3)

            padded = torch.nn.functional.pad(image_torch,
                                             (pad, pad,   # x dimension
                                              pad, pad,   # y dimension
                                              pad, pad),  # z dimension
                                             mode='reflect')
            local_mean = torch.nn.functional.conv3d(padded, kernel)
            features.append(local_mean.cpu().numpy().squeeze())

            mean = torch.mean(image_torch)
            padded_sq = torch.nn.functional.pad((image_torch - mean)**2,
                                                (pad, pad, pad, pad, pad, pad),
                                                mode='reflect')
            local_var = torch.nn.functional.conv3d(padded_sq, kernel)
            features.append(local_var.cpu().numpy().squeeze())

        # Original gradient computations
        print("Computing sobel and gradients")
        kernel_size = 3
        padding = kernel_size // 2
        # NOTE(review): `pad` is rebound here from an int to the padded
        # tensor; later code depends on this exact order.
        pad = torch.nn.functional.pad(image_torch, (padding,)*6, mode='replicate')

        # 1-D central-difference kernels oriented per axis.
        sobel_x = torch.tensor([-1, 0, 1], device='cuda').float().view(1,1,1,1,3)
        sobel_y = torch.tensor([-1, 0, 1], device='cuda').float().view(1,1,1,3,1)
        sobel_z = torch.tensor([-1, 0, 1], device='cuda').float().view(1,1,3,1,1)

        # Crop each response back to the original volume extent.
        gx = torch.nn.functional.conv3d(pad, sobel_x, padding=0)[:,:,:original_shape[0],:original_shape[1],:original_shape[2]]
        gy = torch.nn.functional.conv3d(pad, sobel_y, padding=0)[:,:,:original_shape[0],:original_shape[1],:original_shape[2]]
        gz = torch.nn.functional.conv3d(pad, sobel_z, padding=0)[:,:,:original_shape[0],:original_shape[1],:original_shape[2]]

        gradient_magnitude = torch.sqrt(gx**2 + gy**2 + gz**2)
        features.append(gradient_magnitude.cpu().numpy().squeeze())

        # Second-order gradients
        print("Computing second-order features")
        gxx = torch.nn.functional.conv3d(gx, sobel_x, padding=padding)
        gyy = torch.nn.functional.conv3d(gy, sobel_y, padding=padding)
        gzz = torch.nn.functional.conv3d(gz, sobel_z, padding=padding)

        # Get minimum size in each dimension
        min_size_0 = min(gxx.size(2), gyy.size(2), gzz.size(2))
        min_size_1 = min(gxx.size(3), gyy.size(3), gzz.size(3))
        min_size_2 = min(gxx.size(4), gyy.size(4), gzz.size(4))

        # Crop to smallest common size
        gxx = gxx[:, :, :min_size_0, :min_size_1, :min_size_2]
        gyy = gyy[:, :, :min_size_0, :min_size_1, :min_size_2]
        gzz = gzz[:, :, :min_size_0, :min_size_1, :min_size_2]

        laplacian = gxx + gyy + gzz # Second derivatives in each direction
        features.append(laplacian.cpu().numpy().squeeze())

        # Now they should have matching dimensions for multiplication.
        # NOTE(review): this is the product of the diagonal Hessian terms,
        # not a true determinant (mixed partials are omitted).
        hessian_det = gxx * gyy * gzz
        features.append(hessian_det.cpu().numpy().squeeze())

        print("Verifying shapes")
        for i, feat in enumerate(features):
            if feat.shape != original_shape:
                # A leading singleton axis may have been squeezed away; try
                # restoring it before giving up.
                feat_adjusted = np.expand_dims(feat, axis=0)
                if feat_adjusted.shape != original_shape:
                    raise ValueError(f"Feature {i} has shape {feat.shape}, expected {original_shape}")
                features[i] = feat_adjusted

        return np.stack(features, axis=-1)
91
341
 
@@ -113,6 +363,33 @@ class InteractiveSegmenter:
113
363
  self.model.fit(X, y)
114
364
  self.patterns = []
115
365
 
366
+ def process_chunk_GPU(self, chunk_coords):
367
+ """Process a chunk of coordinates using GPU acceleration"""
368
+ coords = np.array(chunk_coords)
369
+ z, y, x = coords.T
370
+
371
+ # Extract features
372
+ features = self.feature_cache[z, y, x]
373
+
374
+ if self.use_gpu:
375
+ # Move to GPU
376
+ features_gpu = cp.array(features)
377
+
378
+ # Predict on GPU
379
+ predictions = self.model.predict(features_gpu)
380
+ predictions = cp.asnumpy(predictions)
381
+ else:
382
+ predictions = self.model.predict(features)
383
+
384
+ # Split results
385
+ foreground_mask = predictions == 1
386
+ background_mask = ~foreground_mask
387
+
388
+ foreground = set(map(tuple, coords[foreground_mask]))
389
+ background = set(map(tuple, coords[background_mask]))
390
+
391
+ return foreground, background
392
+
116
393
  def process_chunk(self, chunk_coords):
117
394
  """Process a chunk of coordinates"""
118
395
  features = [self.feature_cache[z, y, x] for z, y, x in chunk_coords]
@@ -128,32 +405,62 @@ class InteractiveSegmenter:
128
405
 
129
406
  return foreground, background
130
407
 
131
- def segment_volume(self, chunk_size=32):
132
- """Segment volume using parallel processing of chunks"""
408
+ def segment_volume(self, chunk_size=32, gpu = False):
409
+ """Segment volume using parallel processing of chunks with vectorized chunk creation"""
410
+
411
+ try:
412
+ from cuml.ensemble import RandomForestClassifier as cuRandomForestClassifier
413
+ except:
414
+ print("Cannot find cuml, using CPU to segment instead...")
415
+ gpu = False
416
+
417
+
133
418
  if self.feature_cache is None:
134
419
  with self.lock:
135
420
  if self.feature_cache is None:
136
421
  self.feature_cache = self.compute_feature_maps()
137
422
 
138
- # Create chunks of coordinates
423
+ # Calculate number of chunks in each dimension
424
+ z_chunks = (self.image_3d.shape[0] + chunk_size - 1) // chunk_size
425
+ y_chunks = (self.image_3d.shape[1] + chunk_size - 1) // chunk_size
426
+ x_chunks = (self.image_3d.shape[2] + chunk_size - 1) // chunk_size
427
+
428
+ # Create start indices for all chunks at once
429
+ chunk_starts = np.array(np.meshgrid(
430
+ np.arange(z_chunks) * chunk_size,
431
+ np.arange(y_chunks) * chunk_size,
432
+ np.arange(x_chunks) * chunk_size,
433
+ indexing='ij'
434
+ )).reshape(3, -1).T
435
+
139
436
  chunks = []
140
- for z in range(0, self.image_3d.shape[0], chunk_size):
141
- for y in range(0, self.image_3d.shape[1], chunk_size):
142
- for x in range(0, self.image_3d.shape[2], chunk_size):
143
- chunk_coords = [
144
- (zz, yy, xx)
145
- for zz in range(z, min(z + chunk_size, self.image_3d.shape[0]))
146
- for yy in range(y, min(y + chunk_size, self.image_3d.shape[1]))
147
- for xx in range(x, min(x + chunk_size, self.image_3d.shape[2]))
148
- ]
149
- chunks.append(chunk_coords)
437
+ for z_start, y_start, x_start in chunk_starts:
438
+ z_end = min(z_start + chunk_size, self.image_3d.shape[0])
439
+ y_end = min(y_start + chunk_size, self.image_3d.shape[1])
440
+ x_end = min(x_start + chunk_size, self.image_3d.shape[2])
441
+
442
+ # Create coordinates for this chunk efficiently
443
+ coords = np.stack(np.meshgrid(
444
+ np.arange(z_start, z_end),
445
+ np.arange(y_start, y_end),
446
+ np.arange(x_start, x_end),
447
+ indexing='ij'
448
+ )).reshape(3, -1).T
449
+
450
+ chunks.append(list(map(tuple, coords)))
150
451
 
151
452
  foreground_coords = set()
152
453
  background_coords = set()
153
454
 
154
- # Process chunks in parallel
155
455
  with ThreadPoolExecutor() as executor:
156
- futures = [executor.submit(self.process_chunk, chunk) for chunk in chunks]
456
+ if gpu:
457
+ try:
458
+ futures = [executor.submit(self.process_chunk_GPU, chunk) for chunk in chunks]
459
+ except:
460
+ futures = [executor.submit(self.process_chunk, chunk) for chunk in chunks]
461
+
462
+ else:
463
+ futures = [executor.submit(self.process_chunk, chunk) for chunk in chunks]
157
464
 
158
465
  for i, future in enumerate(futures):
159
466
  fore, back = future.result()
@@ -164,127 +471,223 @@ class InteractiveSegmenter:
164
471
 
165
472
  return foreground_coords, background_coords
166
473
 
167
- def cleanup(self):
168
- """Clean up GPU memory"""
169
- if self.use_gpu:
170
- cp.get_default_memory_pool().free_all_blocks()
171
- torch.cuda.empty_cache()
474
+ def update_position(self, z=None, x=None, y=None):
475
+ """Update current position for chunk prioritization"""
476
+ self.current_z = z
477
+ self.current_x = x
478
+ self.current_y = y
172
479
 
173
- def train_batch(self, foreground_array, background_array):
174
- """Train directly on foreground and background arrays"""
175
- if self.feature_cache is None:
176
- with self.lock:
177
- if self.feature_cache is None:
178
- self.feature_cache = self.compute_feature_maps()
480
+
481
+ def get_realtime_chunks(self, chunk_size = 32):
482
+ print("Computing some overhead...")
483
+
484
+
485
+
486
+ # Determine if we need to chunk XY planes
487
+ small_dims = (self.image_3d.shape[1] <= chunk_size and
488
+ self.image_3d.shape[2] <= chunk_size)
489
+ few_z = self.image_3d.shape[0] <= 100 # arbitrary threshold
179
490
 
180
- # Get foreground coordinates and features
181
- z_fore, y_fore, x_fore = np.where(foreground_array > 0)
182
- foreground_features = self.feature_cache[z_fore, y_fore, x_fore]
491
+ # If small enough, each Z is one chunk
492
+ if small_dims and few_z:
493
+ chunk_size_xy = max(self.image_3d.shape[1], self.image_3d.shape[2])
494
+ else:
495
+ chunk_size_xy = chunk_size
183
496
 
184
- # Get background coordinates and features
185
- z_back, y_back, x_back = np.where(background_array > 0)
186
- background_features = self.feature_cache[z_back, y_back, x_back]
497
+ # Calculate chunks for XY plane
498
+ y_chunks = (self.image_3d.shape[1] + chunk_size_xy - 1) // chunk_size_xy
499
+ x_chunks = (self.image_3d.shape[2] + chunk_size_xy - 1) // chunk_size_xy
187
500
 
188
- # Combine features and labels
189
- X = np.vstack([foreground_features, background_features])
190
- y = np.hstack([np.ones(len(z_fore)), np.zeros(len(z_back))])
501
+ # Populate chunk dictionary
502
+ chunk_dict = {}
191
503
 
192
- # Train the model
193
- self.model.fit(X, y)
194
-
195
- print("Done")
504
+ # Create chunks for each Z plane
505
+ for z in range(self.image_3d.shape[0]):
506
+ if small_dims:
507
+ # One chunk per Z
508
+ coords = np.stack(np.meshgrid(
509
+ [z],
510
+ np.arange(self.image_3d.shape[1]),
511
+ np.arange(self.image_3d.shape[2]),
512
+ indexing='ij'
513
+ )).reshape(3, -1).T
514
+
515
+ chunk_dict[(z, 0, 0)] = {
516
+ 'coords': list(map(tuple, coords)),
517
+ 'processed': False,
518
+ 'z': z
519
+ }
520
+ else:
521
+ # Multiple chunks per Z
522
+ for y_chunk in range(y_chunks):
523
+ for x_chunk in range(x_chunks):
524
+ y_start = y_chunk * chunk_size_xy
525
+ x_start = x_chunk * chunk_size_xy
526
+ y_end = min(y_start + chunk_size_xy, self.image_3d.shape[1])
527
+ x_end = min(x_start + chunk_size_xy, self.image_3d.shape[2])
528
+
529
+ coords = np.stack(np.meshgrid(
530
+ [z],
531
+ np.arange(y_start, y_end),
532
+ np.arange(x_start, x_end),
533
+ indexing='ij'
534
+ )).reshape(3, -1).T
535
+
536
+ chunk_dict[(z, y_start, x_start)] = {
537
+ 'coords': list(map(tuple, coords)),
538
+ 'processed': False,
539
+ 'z': z
540
+ }
196
541
 
542
+ self.realtimechunks = chunk_dict
197
543
 
544
+ print("Ready!")
198
545
 
199
546
 
547
+ def segment_volume_realtime(self, gpu = False):
200
548
 
549
+ try:
550
+ from cuml.ensemble import RandomForestClassifier as cuRandomForestClassifier
551
+ except:
552
+ print("Cannot find cuml, using CPU to segment instead...")
553
+ gpu = False
201
554
 
202
555
 
556
+ if self.realtimechunks is None:
557
+ self.get_realtime_chunks()
558
+ else:
559
+ for chunk_pos in self.realtimechunks: # chunk_pos is the (z, y_start, x_start) tuple
560
+ self.realtimechunks[chunk_pos]['processed'] = False
203
561
 
562
+ chunk_dict = self.realtimechunks
204
563
 
205
- def segment_volume_subprocess(self, chunk_size=32, current_z=None, current_x=None, current_y=None):
206
- """
207
- Segment volume prioritizing chunks near user location.
208
- Returns chunks as they're processed.
209
- """
210
- if self.feature_cache is None:
211
- with self.lock:
212
- if self.feature_cache is None:
213
- self.feature_cache = self.compute_feature_maps()
564
+
565
+ def get_nearest_unprocessed_chunk(self):
566
+ """Get nearest unprocessed chunk prioritizing current Z"""
567
+ curr_z = self.current_z if self.current_z is not None else self.image_3d.shape[0] // 2
568
+ curr_y = self.current_x if self.current_x is not None else self.image_3d.shape[1] // 2
569
+ curr_x = self.current_y if self.current_y is not None else self.image_3d.shape[2] // 2
214
570
 
215
- # Create chunks with position information
216
- chunks_info = []
217
- for z in range(0, self.image_3d.shape[0], chunk_size):
218
- for y in range(0, self.image_3d.shape[1], chunk_size):
219
- for x in range(0, self.image_3d.shape[2], chunk_size):
220
- chunk_coords = [
221
- (zz, yy, xx)
222
- for zz in range(z, min(z + chunk_size, self.image_3d.shape[0]))
223
- for yy in range(y, min(y + chunk_size, self.image_3d.shape[1]))
224
- for xx in range(x, min(x + chunk_size, self.image_3d.shape[2]))
225
- ]
226
-
227
- # Store chunk with its corner position
228
- chunks_info.append({
229
- 'coords': chunk_coords,
230
- 'corner': (z, y, x),
231
- 'processed': False
232
- })
233
-
234
- def get_chunk_priority(chunk):
235
- """Calculate priority based on distance from user position"""
236
- z, y, x = chunk['corner']
237
- priority = 0
238
-
239
- # Priority based on Z distance (always used)
240
- if current_z is not None:
241
- priority += abs(z - current_z)
242
-
243
- # Add X/Y distance if provided
244
- if current_x is not None and current_y is not None:
245
- xy_distance = ((x - current_x) ** 2 + (y - current_y) ** 2) ** 0.5
246
- priority += xy_distance
247
-
248
- return priority
571
+ # First try to find chunks at current Z
572
+ current_z_chunks = [(pos, info) for pos, info in chunk_dict.items()
573
+ if info['z'] == curr_z and not info['processed']]
574
+
575
+ if current_z_chunks:
576
+ # Find nearest chunk in current Z plane
577
+ nearest = min(current_z_chunks,
578
+ key=lambda x: ((x[0][1] - curr_y) ** 2 +
579
+ (x[0][2] - curr_x) ** 2))
580
+ return nearest[0]
581
+
582
+ # If no chunks at current Z, find nearest Z with available chunks
583
+ available_z = sorted(
584
+ [(pos[0], pos) for pos, info in chunk_dict.items()
585
+ if not info['processed']],
586
+ key=lambda x: abs(x[0] - curr_z)
587
+ )
588
+
589
+ if available_z:
590
+ target_z = available_z[0][0]
591
+ # Find nearest chunk in target Z plane
592
+ z_chunks = [(pos, info) for pos, info in chunk_dict.items()
593
+ if info['z'] == target_z and not info['processed']]
594
+ nearest = min(z_chunks,
595
+ key=lambda x: ((x[0][1] - curr_y) ** 2 +
596
+ (x[0][2] - curr_x) ** 2))
597
+ return nearest[0]
598
+
599
+ return None
600
+
249
601
 
250
- with ThreadPoolExecutor() as executor:
251
- futures = {} # Track active futures
252
-
253
- while True:
254
- # Sort unprocessed chunks by priority
255
- unprocessed_chunks = [c for c in chunks_info if not c['processed']]
256
- if not unprocessed_chunks:
257
- break
258
-
259
- # Sort by distance from current position
260
- unprocessed_chunks.sort(key=get_chunk_priority)
602
+ with ThreadPoolExecutor() as executor:
603
+ futures = {}
604
+ import multiprocessing
605
+ total_cores = multiprocessing.cpu_count()
606
+ #available_workers = max(1, min(4, total_cores // 2)) # Use half cores, max of 4
607
+ available_workers = 1
608
+
609
+ while True:
610
+ # Find nearest unprocessed chunk using class attributes
611
+ chunk_idx = get_nearest_unprocessed_chunk(self) # Pass self
612
+ if chunk_idx is None:
613
+ break
261
614
 
262
- # Submit new chunks to replace completed ones
263
- while len(futures) < executor._max_workers and unprocessed_chunks:
264
- chunk = unprocessed_chunks.pop(0)
615
+ while (len(futures) < available_workers and
616
+ (chunk_idx := get_nearest_unprocessed_chunk(self))): # Pass self
617
+ chunk = chunk_dict[chunk_idx]
618
+ if gpu:
619
+ try:
620
+ futures = [executor.submit(self.process_chunk_GPU, chunk) for chunk in chunks]
621
+ except:
622
+ futures = [executor.submit(self.process_chunk, chunk) for chunk in chunks]
623
+ else:
265
624
  future = executor.submit(self.process_chunk, chunk['coords'])
266
- futures[future] = chunk
267
- chunk['processed'] = True
268
-
269
- # Check completed futures
270
- done, _ = concurrent.futures.wait(
271
- futures.keys(),
272
- timeout=0.1,
273
- return_when=concurrent.futures.FIRST_COMPLETED
274
- )
275
-
276
- # Process completed chunks
277
- for future in done:
278
- chunk = futures[future]
279
- fore, back = future.result()
280
-
281
- # Yield chunk results with position information
282
- yield {
283
- 'foreground': fore,
284
- 'background': back,
285
- 'corner': chunk['corner'],
286
- 'size': chunk_size
287
- }
288
-
289
- del futures[future]
625
+
626
+ futures[future] = chunk_idx
627
+ chunk['processed'] = True
628
+
629
+ # Check completed futures
630
+ done, _ = concurrent.futures.wait(
631
+ futures.keys(),
632
+ timeout=0.1,
633
+ return_when=concurrent.futures.FIRST_COMPLETED
634
+ )
635
+
636
+ # Process completed chunks
637
+ for future in done:
638
+ fore, back = future.result()
639
+ del futures[future]
640
+ yield fore, back
641
+
642
+ def cleanup(self):
643
+ """Clean up GPU memory"""
644
+ if self.use_gpu:
645
+ cp.get_default_memory_pool().free_all_blocks()
646
+ torch.cuda.empty_cache()
647
+
648
+ def train_batch(self, foreground_array, speed = True, use_gpu = False):
649
+ """Train directly on foreground and background arrays"""
650
+
651
+ if self.current_speed != speed:
652
+ self.feature_cache = None
653
+
654
+ if self.feature_cache is None:
655
+ with self.lock:
656
+ if self.feature_cache is None and speed:
657
+ if use_gpu:
658
+ self.feature_cache = self.compute_feature_maps()
659
+ else:
660
+ self.feature_cache = self.compute_feature_maps_cpu()
661
+
662
+ elif self.feature_cache is None and not speed:
663
+ if use_gpu:
664
+
665
+ self.feature_cache = self.compute_deep_feature_maps()
666
+ else:
667
+ self.feature_cache = self.compute_deep_feature_maps_cpu()
668
+
669
+
670
+ try:
671
+ # Get foreground coordinates and features
672
+ z_fore, y_fore, x_fore = np.where(foreground_array == 1)
673
+ foreground_features = self.feature_cache[z_fore, y_fore, x_fore]
674
+
675
+ # Get background coordinates and features
676
+ z_back, y_back, x_back = np.where(foreground_array == 2)
677
+ background_features = self.feature_cache[z_back, y_back, x_back]
678
+
679
+ # Combine features and labels
680
+ X = np.vstack([foreground_features, background_features])
681
+ y = np.hstack([np.ones(len(z_fore)), np.zeros(len(z_back))])
682
+
683
+ # Train the model
684
+ self.model.fit(X, y)
685
+
686
+ self.current_speed = speed
687
+ except:
688
+ print("Features maps computed, but no segmentation examples were provided so the model was not trained")
689
+
690
+
691
+ print("Done")
692
+
290
693