nettracer3d 0.4.3__tar.gz → 0.4.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. {nettracer3d-0.4.3/src/nettracer3d.egg-info → nettracer3d-0.4.4}/PKG-INFO +2 -2
  2. {nettracer3d-0.4.3 → nettracer3d-0.4.4}/README.md +1 -1
  3. {nettracer3d-0.4.3 → nettracer3d-0.4.4}/pyproject.toml +1 -1
  4. {nettracer3d-0.4.3 → nettracer3d-0.4.4}/src/nettracer3d/nettracer.py +6 -2
  5. {nettracer3d-0.4.3 → nettracer3d-0.4.4}/src/nettracer3d/nettracer_gui.py +539 -14
  6. nettracer3d-0.4.4/src/nettracer3d/segmenter.py +290 -0
  7. {nettracer3d-0.4.3 → nettracer3d-0.4.4}/src/nettracer3d/smart_dilate.py +44 -3
  8. {nettracer3d-0.4.3 → nettracer3d-0.4.4/src/nettracer3d.egg-info}/PKG-INFO +2 -2
  9. {nettracer3d-0.4.3 → nettracer3d-0.4.4}/src/nettracer3d.egg-info/SOURCES.txt +1 -0
  10. {nettracer3d-0.4.3 → nettracer3d-0.4.4}/LICENSE +0 -0
  11. {nettracer3d-0.4.3 → nettracer3d-0.4.4}/setup.cfg +0 -0
  12. {nettracer3d-0.4.3 → nettracer3d-0.4.4}/src/nettracer3d/__init__.py +0 -0
  13. {nettracer3d-0.4.3 → nettracer3d-0.4.4}/src/nettracer3d/community_extractor.py +0 -0
  14. {nettracer3d-0.4.3 → nettracer3d-0.4.4}/src/nettracer3d/hub_getter.py +0 -0
  15. {nettracer3d-0.4.3 → nettracer3d-0.4.4}/src/nettracer3d/modularity.py +0 -0
  16. {nettracer3d-0.4.3 → nettracer3d-0.4.4}/src/nettracer3d/morphology.py +0 -0
  17. {nettracer3d-0.4.3 → nettracer3d-0.4.4}/src/nettracer3d/network_analysis.py +0 -0
  18. {nettracer3d-0.4.3 → nettracer3d-0.4.4}/src/nettracer3d/network_draw.py +0 -0
  19. {nettracer3d-0.4.3 → nettracer3d-0.4.4}/src/nettracer3d/node_draw.py +0 -0
  20. {nettracer3d-0.4.3 → nettracer3d-0.4.4}/src/nettracer3d/proximity.py +0 -0
  21. {nettracer3d-0.4.3 → nettracer3d-0.4.4}/src/nettracer3d/run.py +0 -0
  22. {nettracer3d-0.4.3 → nettracer3d-0.4.4}/src/nettracer3d/simple_network.py +0 -0
  23. {nettracer3d-0.4.3 → nettracer3d-0.4.4}/src/nettracer3d.egg-info/dependency_links.txt +0 -0
  24. {nettracer3d-0.4.3 → nettracer3d-0.4.4}/src/nettracer3d.egg-info/entry_points.txt +0 -0
  25. {nettracer3d-0.4.3 → nettracer3d-0.4.4}/src/nettracer3d.egg-info/requires.txt +0 -0
  26. {nettracer3d-0.4.3 → nettracer3d-0.4.4}/src/nettracer3d.egg-info/top_level.txt +0 -0
--- nettracer3d-0.4.3/src/nettracer3d.egg-info/PKG-INFO
+++ nettracer3d-0.4.4/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: nettracer3d
-Version: 0.4.3
+Version: 0.4.4
 Summary: Scripts for intializing and analyzing networks from segmentations of three dimensional images.
 Author-email: Liam McLaughlin <boom2449@gmail.com>
 Project-URL: User_Manual, https://drive.google.com/drive/folders/1fTkz3n4LN9_VxKRKC8lVQSlrz_wq0bVn?usp=drive_link
@@ -32,7 +32,7 @@ Requires-Dist: cupy-cuda12x; extra == "cuda12"
 Provides-Extra: cupy
 Requires-Dist: cupy; extra == "cupy"
 
-NetTracer3D is a python package developed for both 2D and 3D analysis of microscopic images in the .tif file format. It supports generation of 3D networks showing the relationships between objects (or nodes) in three dimensional space, either based on their own proximity or connectivity via connecting objects such as nerves or blood vessels. In addition to these functionalities are several advanced 3D data processing algorithms, such as labeling of branched structures or abstraction of branched structures into networks. Note that nettracer3d uses segmented data, which can be segmented from other softwares such as ImageJ and imported into NetTracer3D, although it does offer its own segmentation via intensity or volumetric thresholding. NetTracer3D currently has a fully functional GUI. To use the GUI, after installing the nettracer3d package via pip, enter the command 'nettracer3d' in your command prompt:
+NetTracer3D is a python package developed for both 2D and 3D analysis of microscopic images in the .tif file format. It supports generation of 3D networks showing the relationships between objects (or nodes) in three dimensional space, either based on their own proximity or connectivity via connecting objects such as nerves or blood vessels. In addition to these functionalities are several advanced 3D data processing algorithms, such as labeling of branched structures or abstraction of branched structures into networks. Note that nettracer3d uses segmented data, which can be segmented from other softwares such as ImageJ and imported into NetTracer3D, although it does offer its own segmentation via intensity and volumetric thresholding, or random forest machine learning segmentation. NetTracer3D currently has a fully functional GUI. To use the GUI, after installing the nettracer3d package via pip, enter the command 'nettracer3d' in your command prompt:
 
 
 This gui is built from the PyQt6 package and therefore may not function on dockers or virtual envs that are unable to support PyQt6 displays. More advanced documentation (especially for the GUI) is coming down the line, but for now please see: https://drive.google.com/drive/folders/1fTkz3n4LN9_VxKRKC8lVQSlrz_wq0bVn?usp=drive_link
--- nettracer3d-0.4.3/README.md
+++ nettracer3d-0.4.4/README.md
@@ -1,4 +1,4 @@
-NetTracer3D is a python package developed for both 2D and 3D analysis of microscopic images in the .tif file format. It supports generation of 3D networks showing the relationships between objects (or nodes) in three dimensional space, either based on their own proximity or connectivity via connecting objects such as nerves or blood vessels. In addition to these functionalities are several advanced 3D data processing algorithms, such as labeling of branched structures or abstraction of branched structures into networks. Note that nettracer3d uses segmented data, which can be segmented from other softwares such as ImageJ and imported into NetTracer3D, although it does offer its own segmentation via intensity or volumetric thresholding. NetTracer3D currently has a fully functional GUI. To use the GUI, after installing the nettracer3d package via pip, enter the command 'nettracer3d' in your command prompt:
+NetTracer3D is a python package developed for both 2D and 3D analysis of microscopic images in the .tif file format. It supports generation of 3D networks showing the relationships between objects (or nodes) in three dimensional space, either based on their own proximity or connectivity via connecting objects such as nerves or blood vessels. In addition to these functionalities are several advanced 3D data processing algorithms, such as labeling of branched structures or abstraction of branched structures into networks. Note that nettracer3d uses segmented data, which can be segmented from other softwares such as ImageJ and imported into NetTracer3D, although it does offer its own segmentation via intensity and volumetric thresholding, or random forest machine learning segmentation. NetTracer3D currently has a fully functional GUI. To use the GUI, after installing the nettracer3d package via pip, enter the command 'nettracer3d' in your command prompt:
 
 
 This gui is built from the PyQt6 package and therefore may not function on dockers or virtual envs that are unable to support PyQt6 displays. More advanced documentation (especially for the GUI) is coming down the line, but for now please see: https://drive.google.com/drive/folders/1fTkz3n4LN9_VxKRKC8lVQSlrz_wq0bVn?usp=drive_link
--- nettracer3d-0.4.3/pyproject.toml
+++ nettracer3d-0.4.4/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "nettracer3d"
-version = "0.4.3"
+version = "0.4.4"
 authors = [
   { name="Liam McLaughlin", email="boom2449@gmail.com" },
 ]
--- nettracer3d-0.4.3/src/nettracer3d/nettracer.py
+++ nettracer3d-0.4.4/src/nettracer3d/nettracer.py
@@ -1836,6 +1836,7 @@ def watershed(image, directory = None, proportion = 0.1, GPU = True, smallest_ra
             gotoexcept = 1/0
 
         except (cp.cuda.memory.OutOfMemoryError, ZeroDivisionError) as e:
+
             if predownsample is None:
                 down_factor = smart_dilate.catch_memory(e) #Obtain downsample amount based on memory missing
             else:
@@ -1866,8 +1867,12 @@ def watershed(image, directory = None, proportion = 0.1, GPU = True, smallest_ra
 
     labels, _ = label_objects(distance)
 
+    if len(labels.shape) ==2:
+        labels = np.expand_dims(labels, axis = 0)
+
     del distance
 
+
     if labels.shape[1] < original_shape[1]: #If downsample was used, upsample output
         labels = upsample_with_padding(labels, downsample_needed, original_shape)
         labels = labels * old_mask
@@ -1876,8 +1881,7 @@ def watershed(image, directory = None, proportion = 0.1, GPU = True, smallest_ra
         labels = smart_dilate.smart_label(image, labels, GPU = GPU, predownsample = predownsample2)
 
     if directory is None:
-        tifffile.imwrite("Watershed_output.tif", labels)
-        print("Watershed saved to 'Watershed_output.tif'")
+        pass
    else:
        tifffile.imwrite(f"{directory}/Watershed_output.tif", labels)
        print(f"Watershed saved to {directory}/'Watershed_output.tif'")
--- nettracer3d-0.4.3/src/nettracer3d/nettracer_gui.py
+++ nettracer3d-0.4.4/src/nettracer3d/nettracer_gui.py
@@ -18,12 +18,13 @@ from nettracer3d import proximity as pxt
 from matplotlib.colors import LinearSegmentedColormap
 from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
 import pandas as pd
-from PyQt6.QtGui import (QFont, QCursor, QColor)
+from PyQt6.QtGui import (QFont, QCursor, QColor, QPixmap, QPainter, QPen)
 import tifffile
 import copy
 import multiprocessing as mp
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
+from nettracer3d import segmenter
 
 
 class ImageViewerWindow(QMainWindow):
@@ -106,11 +107,21 @@ class ImageViewerWindow(QMainWindow):
         self.zoom_mode = False
         self.original_xlim = None
         self.original_ylim = None
+        self.zoom_changed = False
 
         # Pan mode state
         self.pan_mode = False
         self.panning = False
         self.pan_start = None
+
+        #For ML segmenting mode
+        self.brush_mode = False
+        self.painting = False
+        self.foreground = True
+        self.machine_window = None
+        self.brush_size = 1 # Start with 1 pixel
+        self.min_brush_size = 1
+        self.max_brush_size = 10
 
         # Store brightness/contrast values for each channel
         self.channel_brightness = [{
@@ -240,6 +251,8 @@ class ImageViewerWindow(QMainWindow):
         self.ax = self.figure.add_subplot(111)
         left_layout.addWidget(self.canvas)
 
+        self.canvas.mpl_connect('scroll_event', self.on_mpl_scroll)
+
 
         left_layout.addWidget(control_panel)
 
@@ -348,6 +361,7 @@ class ImageViewerWindow(QMainWindow):
         self.canvas.mpl_connect('button_press_event', self.on_mouse_press)
         self.canvas.mpl_connect('button_release_event', self.on_mouse_release)
         self.canvas.mpl_connect('motion_notify_event', self.on_mouse_move)
+
         #self.canvas.mpl_connect('button_press_event', self.on_mouse_click)
 
         # Initialize measurement points tracking
@@ -1423,9 +1437,16 @@ class ImageViewerWindow(QMainWindow):
         if self.zoom_mode:
             self.pan_button.setChecked(False)
             self.pan_mode = False
+            self.brush_mode = False
+            if self.machine_window is not None:
+                self.machine_window.silence_button()
             self.canvas.setCursor(Qt.CursorShape.CrossCursor)
         else:
-            self.canvas.setCursor(Qt.CursorShape.ArrowCursor)
+            if self.machine_window is None:
+                self.canvas.setCursor(Qt.CursorShape.ArrowCursor)
+            else:
+                self.machine_window.toggle_brush_button()
+
 
     def toggle_pan_mode(self):
         """Toggle pan mode on/off."""
@@ -1433,14 +1454,132 @@ class ImageViewerWindow(QMainWindow):
         if self.pan_mode:
             self.zoom_button.setChecked(False)
             self.zoom_mode = False
+            self.brush_mode = False
+            if self.machine_window is not None:
+                self.machine_window.silence_button()
             self.canvas.setCursor(Qt.CursorShape.OpenHandCursor)
         else:
-            self.canvas.setCursor(Qt.CursorShape.ArrowCursor)
+            if self.machine_window is None:
+                self.canvas.setCursor(Qt.CursorShape.ArrowCursor)
+            else:
+                self.machine_window.toggle_brush_button()
+
+
+
+    def on_mpl_scroll(self, event):
+        """Handle matplotlib canvas scroll events"""
+        #Wheel events
+        if self.brush_mode and event.inaxes == self.ax:
+            # Check if Ctrl is pressed
+            if event.guiEvent.modifiers() & Qt.ShiftModifier:
+                pass
+
+            elif event.guiEvent.modifiers() & Qt.ControlModifier:
+                # Brush size adjustment code...
+                step = 1 if event.button == 'up' else -1
+                new_size = self.brush_size + step
+
+                if new_size < self.min_brush_size:
+                    new_size = self.min_brush_size
+                elif new_size > self.max_brush_size:
+                    new_size = self.max_brush_size
+
+                self.brush_size = new_size
+                self.update_brush_cursor()
+                event.guiEvent.accept()
+                return
+
+        # General scrolling code outside the brush mode condition
+        step = 1 if event.button == 'up' else -1
+
+        if event.guiEvent.modifiers() & Qt.ShiftModifier:
+            if event.guiEvent.modifiers() & Qt.ControlModifier:
+                step = step * 3
+        if (self.current_slice + step) < 0 or (self.current_slice + step) > self.slice_slider.maximum():
+            return
+
+        self.current_slice = self.current_slice + step
+        self.slice_slider.setValue(self.current_slice + step)
+
+        current_xlim = self.ax.get_xlim() if hasattr(self, 'ax') and self.ax.get_xlim() != (0, 1) else None
+        current_ylim = self.ax.get_ylim() if hasattr(self, 'ax') and self.ax.get_ylim() != (0, 1) else None
+
+        self.update_display(preserve_zoom=(current_xlim, current_ylim))
+
+    def keyPressEvent(self, event):
+        if event.key() == Qt.Key_Z:
+            self.zoom_button.click()
+        if self.machine_window is not None:
+            if event.key() == Qt.Key_A:
+                self.machine_window.switch_foreground()
+
+
+    def update_brush_cursor(self):
+        """Update the cursor to show brush size"""
+        if not self.brush_mode:
+            return
+
+        # Create a pixmap for the cursor
+        size = self.brush_size * 2 + 2 # Add padding for border
+        pixmap = QPixmap(size, size)
+        pixmap.fill(Qt.transparent)
+
+        # Create painter for the pixmap
+        painter = QPainter(pixmap)
+        painter.setRenderHint(QPainter.RenderHint.Antialiasing)
+
+        # Draw circle
+        pen = QPen(Qt.white)
+        pen.setWidth(1)
+        painter.setPen(pen)
+        painter.setBrush(Qt.transparent)
+        painter.drawEllipse(1, 1, size-2, size-2)
+
+        # Create cursor from pixmap
+        cursor = QCursor(pixmap)
+        self.canvas.setCursor(cursor)
+
+        painter.end()
+
+    def get_line_points(self, x0, y0, x1, y1):
+        """Get all points in a line between (x0,y0) and (x1,y1) using Bresenham's algorithm"""
+        points = []
+        dx = abs(x1 - x0)
+        dy = abs(y1 - y0)
+        x, y = x0, y0
+        sx = 1 if x0 < x1 else -1
+        sy = 1 if y0 < y1 else -1
+
+        if dx > dy:
+            err = dx / 2.0
+            while x != x1:
+                points.append((x, y))
+                err -= dy
+                if err < 0:
+                    y += sy
+                    err += dx
+                x += sx
+        else:
+            err = dy / 2.0
+            while y != y1:
+                points.append((x, y))
+                err -= dx
+                if err < 0:
+                    x += sx
+                    err += dy
+                y += sy
+
+        points.append((x, y))
+        return points
 
     def on_mouse_press(self, event):
         """Handle mouse press events."""
         if event.inaxes != self.ax:
             return
+
+        if event.button == 2:
+            self.pan_button.click()
+            return
 
         if self.zoom_mode:
             # Handle zoom mode press
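`get_line_points` exists because motion events arrive at discrete intervals: a fast brush drag would otherwise leave gaps between stamps, so consecutive cursor positions are joined with a Bresenham raster line. A standalone sketch of the same algorithm (`bresenham` is a hypothetical free-function rename; the body mirrors the method above):

    def bresenham(x0, y0, x1, y1):
        """All integer points on the segment from (x0, y0) to (x1, y1)."""
        points = []
        dx, dy = abs(x1 - x0), abs(y1 - y0)
        sx = 1 if x0 < x1 else -1
        sy = 1 if y0 < y1 else -1
        x, y = x0, y0
        if dx > dy:                      # x is the driving axis
            err = dx / 2.0
            while x != x1:
                points.append((x, y))
                err -= dy
                if err < 0:
                    y += sy
                    err += dx
                x += sx
        else:                            # y is the driving axis
            err = dy / 2.0
            while y != y1:
                points.append((x, y))
                err -= dx
                if err < 0:
                    x += sx
                    err += dy
                y += sy
        points.append((x, y))
        return points

    # Two motion events several pixels apart still yield a gap-free stroke:
    assert bresenham(0, 0, 4, 2) == [(0, 0), (1, 0), (2, 1), (3, 1), (4, 2)]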
@@ -1459,6 +1598,11 @@
 
             self.ax.set_xlim([xdata - x_range, xdata + x_range])
             self.ax.set_ylim([ydata - y_range, ydata + y_range])
+
+            self.zoom_changed = True # Flag that zoom has changed
+
+            if not hasattr(self, 'zoom_changed'):
+                self.zoom_changed = False
 
         elif event.button == 3: # Right click - zoom out
             x_range = (current_xlim[1] - current_xlim[0])
@@ -1476,6 +1620,11 @@
             else:
                 self.ax.set_xlim(new_xlim)
                 self.ax.set_ylim(new_ylim)
+
+                self.zoom_changed = False # Flag that zoom has changed
+
+                if not hasattr(self, 'zoom_changed'):
+                    self.zoom_changed = False
 
             self.canvas.draw()
 
@@ -1483,6 +1632,50 @@
             self.panning = True
             self.pan_start = (event.xdata, event.ydata)
             self.canvas.setCursor(Qt.CursorShape.ClosedHandCursor)
+
+        elif self.brush_mode:
+            if event.inaxes != self.ax:
+                return
+
+
+            if event.button == 1 or event.button == 3:
+
+                if event.button == 3:
+                    self.erase = True
+                else:
+                    self.erase = False
+
+                self.painting = True
+                x, y = int(event.xdata), int(event.ydata)
+                self.last_paint_pos = (x, y)
+
+                if self.foreground:
+                    channel = 2
+                else:
+                    channel = 3
+
+
+                # Paint at initial position
+                self.paint_at_position(x, y, self.erase, channel)
+
+                current_xlim = self.ax.get_xlim() if hasattr(self, 'ax') and self.ax.get_xlim() != (0, 1) else None
+                current_ylim = self.ax.get_ylim() if hasattr(self, 'ax') and self.ax.get_ylim() != (0, 1) else None
+
+
+                self.canvas.draw()
+                #self.update_display(preserve_zoom=(current_xlim, current_ylim))
+                self.restore_channels = []
+
+
+                for i in range(4):
+                    if i == channel:
+                        self.channel_visible[i] = True
+                    elif self.channel_data[i] is not None and self.channel_visible[i] == True:
+                        self.channel_visible[i] = False
+                        self.restore_channels.append(i)
+                self.update_display(preserve_zoom = (current_xlim, current_ylim), begin_paint = True)
+                self.update_display_slice(channel, preserve_zoom=(current_xlim, current_ylim))
+
 
         elif event.button == 3: # Right click (for context menu)
             self.create_context_menu(event)
@@ -1492,12 +1685,32 @@
             self.selection_start = (event.xdata, event.ydata)
             self.selecting = False # Will be set to True if the mouse moves while button is held
 
+    def paint_at_position(self, center_x, center_y, erase = False, channel = 2):
+        """Paint pixels within brush radius at given position"""
+        if self.channel_data[channel] is None:
+            return
+
+        if erase:
+            val = 0
+        else:
+            val = 255
+
+        height, width = self.channel_data[channel][self.current_slice].shape
+        radius = self.brush_size // 2
+
+        # Calculate brush area
+        for y in range(max(0, center_y - radius), min(height, center_y + radius + 1)):
+            for x in range(max(0, center_x - radius), min(width, center_x + radius + 1)):
+                # Check if point is within circular brush area
+                if (x - center_x) ** 2 + (y - center_y) ** 2 <= radius ** 2:
+                    self.channel_data[channel][self.current_slice][y, x] = val
+
     def on_mouse_move(self, event):
         """Handle mouse movement events."""
         if event.inaxes != self.ax:
             return
 
-        if self.selection_start and not self.selecting and not self.pan_mode and not self.zoom_mode:
+        if self.selection_start and not self.selecting and not self.pan_mode and not self.zoom_mode and not self.brush_mode:
             # If mouse has moved more than a tiny amount while button is held, start selection
             if (abs(event.xdata - self.selection_start[0]) > 1 or
                 abs(event.ydata - self.selection_start[1]) > 1):
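`paint_at_position` stamps a filled circle with a per-pixel double loop over the brush's bounding box, writing 255 (paint) or 0 (erase) into the scribble channel for the current slice. The same stamp can be expressed with a vectorized mask; a hedged sketch of that alternative (`stamp_circle` is not part of the package):

    import numpy as np

    def stamp_circle(slice_2d, cx, cy, brush_size, val=255):
        """Write `val` into all pixels within the circular brush (vectorized)."""
        r = brush_size // 2
        h, w = slice_2d.shape
        y0, y1 = max(0, cy - r), min(h, cy + r + 1)   # clip the bounding box
        x0, x1 = max(0, cx - r), min(w, cx + r + 1)
        ys, xs = np.ogrid[y0:y1, x0:x1]               # broadcastable index grids
        mask = (xs - cx) ** 2 + (ys - cy) ** 2 <= r ** 2
        slice_2d[y0:y1, x0:x1][mask] = val

    canvas = np.zeros((64, 64), dtype=np.uint8)
    stamp_circle(canvas, 32, 32, brush_size=10)        # paint
    stamp_circle(canvas, 32, 32, brush_size=6, val=0)  # erase a smaller core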
@@ -1519,6 +1732,7 @@
                 self.canvas.draw()
 
         elif self.panning and self.pan_start is not None:
+
             # Calculate the movement
             dx = event.xdata - self.pan_start[0]
             dy = event.ydata - self.pan_start[1]
@@ -1554,6 +1768,39 @@
             # Update pan start position
             self.pan_start = (event.xdata, event.ydata)
 
+        elif self.painting and self.brush_mode:
+            if event.inaxes != self.ax:
+                return
+
+            x, y = int(event.xdata), int(event.ydata)
+
+            if self.foreground:
+                channel = 2
+            else:
+                channel = 3
+
+
+            if self.channel_data[2] is not None:
+                current_xlim = self.ax.get_xlim() if hasattr(self, 'ax') and self.ax.get_xlim() != (0, 1) else None
+                current_ylim = self.ax.get_ylim() if hasattr(self, 'ax') and self.ax.get_ylim() != (0, 1) else None
+                height, width = self.channel_data[2][self.current_slice].shape
+
+                if hasattr(self, 'last_paint_pos'):
+                    last_x, last_y = self.last_paint_pos
+                    points = self.get_line_points(last_x, last_y, x, y)
+
+                    # Paint at each point along the line
+                    for px, py in points:
+                        if 0 <= px < width and 0 <= py < height:
+                            self.paint_at_position(px, py, self.erase, channel)
+
+                self.last_paint_pos = (x, y)
+
+                self.canvas.draw()
+                #self.update_display(preserve_zoom=(current_xlim, current_ylim))
+                self.update_display_slice(channel, preserve_zoom=(current_xlim, current_ylim))
+
+
     def on_mouse_release(self, event):
         """Handle mouse release events."""
         if self.pan_mode:
@@ -1615,15 +1862,26 @@
         elif not self.selecting and self.selection_start: # If we had a click but never started selection
             # Handle as a normal click
             self.on_mouse_click(event)
+
 
         # Clean up
         self.selection_start = None
         self.selecting = False
+
+
         if self.selection_rect is not None:
             self.selection_rect.remove()
             self.selection_rect = None
             self.canvas.draw()
 
+        if self.brush_mode:
+            self.painting = False
+            for i in self.restore_channels:
+                self.channel_visible[i] = True
+            current_xlim = self.ax.get_xlim() if hasattr(self, 'ax') and self.ax.get_xlim() != (0, 1) else None
+            current_ylim = self.ax.get_ylim() if hasattr(self, 'ax') and self.ax.get_ylim() != (0, 1) else None
+            self.update_display(preserve_zoom=(current_xlim, current_ylim))
+
 
     def highlight_value_in_tables(self, clicked_value):
         """Helper method to find and highlight a value in both tables."""
@@ -1718,6 +1976,11 @@
 
             self.ax.set_xlim([xdata - x_range, xdata + x_range])
             self.ax.set_ylim([ydata - y_range, ydata + y_range])
+
+            self.zoom_changed = True # Flag that zoom has changed
+
+            if not hasattr(self, 'zoom_changed'):
+                self.zoom_changed = False
 
         elif event.button == 3: # Right click - zoom out
             x_range = (current_xlim[1] - current_xlim[0])
@@ -1735,6 +1998,11 @@
             else:
                 self.ax.set_xlim(new_xlim)
                 self.ax.set_ylim(new_ylim)
+
+
+                self.zoom_changed = False # Flag that zoom has changed
+
+
 
             self.canvas.draw()
 
@@ -1748,7 +2016,7 @@
         x_idx = int(round(event.xdata))
         y_idx = int(round(event.ydata))
         # Check if Ctrl key is pressed (using matplotlib's key_press system)
-        ctrl_pressed = 'ctrl' in event.modifiers # Note: changed from 'control' to 'ctrl'
+        ctrl_pressed = 'ctrl' in event.modifiers
         if self.channel_data[self.active_channel][self.current_slice, y_idx, x_idx] != 0:
             clicked_value = self.channel_data[self.active_channel][self.current_slice, y_idx, x_idx]
         else:
@@ -2636,6 +2904,7 @@
             "TIFF Files (*.tif *.tiff)"
         )
         self.channel_data[channel_index] = tifffile.imread(filename)
+
         if len(self.channel_data[channel_index].shape) == 2: # handle 2d data
             self.channel_data[channel_index] = np.expand_dims(self.channel_data[channel_index], axis=0)
 
@@ -2651,10 +2920,14 @@
         if len(self.channel_data[channel_index].shape) == 4 and (channel_index == 0 or channel_index == 1):
             self.channel_data[channel_index] = self.reduce_rgb_dimension(self.channel_data[channel_index])
 
+        reset_resize = False
+
         for i in range(4): #Try to ensure users don't load in different sized arrays
             if self.channel_data[i] is None or i == channel_index or data:
                 if self.highlight_overlay is not None: #Make sure highlight overlay is always the same shape as new images
                     if self.channel_data[i].shape[:3] != self.highlight_overlay.shape:
+                        self.resizing = True
+                        reset_resize = True
                         self.highlight_overlay = None
                 continue
             else:
@@ -2725,7 +2998,7 @@
         if len(self.original_shape) == 4:
             self.original_shape = (self.original_shape[0], self.original_shape[1], self.original_shape[2])
 
-        self.update_display()
+        self.update_display(reset_resize = reset_resize)
 
 
 
@@ -2952,10 +3225,18 @@
         self.channel_brightness[channel_index]['min'] = min_val / 255 #Accomodate 32 bit data?
         self.channel_brightness[channel_index]['max'] = max_val / 255
         self.update_display(preserve_zoom = (current_xlim, current_ylim))
+
+
+
 
-    def update_display(self, preserve_zoom=None, dims = None, called = False):
+    def update_display(self, preserve_zoom=None, dims = None, called = False, reset_resize = False, begin_paint = False):
         """Update the display with currently visible channels and highlight overlay."""
 
+        if begin_paint:
+            # Store/update the static background with current zoom level
+            self.static_background = self.canvas.copy_from_bbox(self.ax.bbox)
+
+
         self.figure.clear()
 
         # Get active channels and their dimensions
@@ -3111,9 +3392,42 @@
         if current_xlim is not None and current_ylim is not None:
             self.ax.set_xlim(current_xlim)
             self.ax.set_ylim(current_ylim)
+        if reset_resize:
+            self.resizing = False
 
         self.canvas.draw()
 
+    def update_display_slice(self, channel, preserve_zoom=None):
+        """Ultra minimal update that only changes the paint channel's data"""
+        if not self.channel_visible[channel]:
+            return
+
+        if preserve_zoom:
+            current_xlim, current_ylim = preserve_zoom
+            if current_xlim is not None and current_ylim is not None:
+                self.ax.set_xlim(current_xlim)
+                self.ax.set_ylim(current_ylim)
+
+
+        # Find the existing image for channel (paint channel)
+        channel_image = None
+        for img in self.ax.images:
+            if img.cmap.name == f'custom_{channel}':
+                channel_image = img
+                break
+
+        if channel_image is not None:
+            # Update the data of the existing image
+            channel_image.set_array(self.channel_data[channel][self.current_slice])
+
+            # Restore the static background (all other channels) at current zoom level
+            self.canvas.restore_region(self.static_background)
+            # Draw just our paint channel
+            self.ax.draw_artist(channel_image)
+            # Blit everything
+            self.canvas.blit(self.ax.bbox)
+            self.canvas.flush_events()
+
     def show_netshow_dialog(self):
         dialog = NetShowDialog(self)
         dialog.exec()
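`update_display_slice` is a blitting fast path: `update_display(begin_paint=True)` renders the figure once and caches the static pixels with `copy_from_bbox`, and each subsequent motion event only restores that cache, redraws the single paint-channel artist, and blits, instead of clearing and rebuilding the whole figure. The generic matplotlib pattern behind it, as a self-contained sketch (not the package's code):

    import matplotlib.pyplot as plt
    import numpy as np

    fig, ax = plt.subplots()
    ax.imshow(np.random.rand(64, 64), cmap='gray')            # static background layer
    overlay = ax.imshow(np.zeros((64, 64)), cmap='Reds',
                        alpha=0.5, vmin=0, vmax=1)            # the "paint" layer
    fig.canvas.draw()                                         # full render, once
    background = fig.canvas.copy_from_bbox(ax.bbox)           # cache static pixels

    data = np.zeros((64, 64))
    for i in range(10):                                       # e.g. successive brush stamps
        data[i, i] = 1.0
        overlay.set_array(data)
        fig.canvas.restore_region(background)                 # cheap: repaint the cache
        ax.draw_artist(overlay)                               # redraw only the overlay
        fig.canvas.blit(ax.bbox)                              # composite the result
        fig.canvas.flush_events()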
@@ -5390,6 +5704,12 @@
         run_button.clicked.connect(self.thresh_mode)
         layout.addRow(run_button)
 
+        # Add ML button
+        ML = QPushButton("Machine Learning")
+        ML.clicked.connect(self.start_ml)
+        layout.addRow(ML)
+
+
     def thresh_mode(self):
 
         try:
@@ -5410,6 +5730,192 @@
         except:
             pass
 
+    def start_ml(self):
+
+
+        if self.parent().channel_data[2] is not None or self.parent().channel_data[3] is not None or self.parent().highlight_overlay is not None:
+            if self.confirm_machine_dialog():
+                pass
+            else:
+                return
+        elif self.parent().channel_data[0] is None and self.parent().channel_data[1] is None:
+            QMessageBox.critical(
+                self,
+                "Alert",
+                "Requires the channel for segmentation to be loaded into either the nodes or edges channels"
+            )
+            return
+
+
+        self.parent().machine_window = MachineWindow(self.parent())
+        self.parent().machine_window.show() # Non-modal window
+        self.accept()
+
+    def confirm_machine_dialog(self):
+        """Shows a dialog asking user to confirm if they want to start the segmenter"""
+        msg = QMessageBox()
+        msg.setIcon(QMessageBox.Icon.Question)
+        msg.setText("Alert")
+        msg.setInformativeText("Use of this feature will require use of both overlay channels and the highlight overlay. Please save any data and return, or proceed if you do not need those overlays")
+        msg.setWindowTitle("Proceed?")
+        msg.setStandardButtons(QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No)
+        return msg.exec() == QMessageBox.StandardButton.Yes
+
+
+class MachineWindow(QMainWindow):
+
+    def __init__(self, parent=None):
+        super().__init__(parent)
+
+        self.setWindowTitle("Threshold")
+
+        # Create central widget and layout
+        central_widget = QWidget()
+        self.setCentralWidget(central_widget)
+        layout = QVBoxLayout(central_widget)
+
+
+        # Create form layout for inputs
+        form_layout = QFormLayout()
+
+        layout.addLayout(form_layout)
+
+
+
+
+        if self.parent().active_channel == 0:
+            if self.parent().channel_data[0] is not None:
+                active_data = self.parent().channel_data[0]
+            else:
+                active_data = self.parent().channel_data[1]
+
+
+        array1 = np.zeros_like(active_data)
+        array2 = np.zeros_like(active_data)
+        array3 = np.zeros_like(active_data)
+        self.parent().highlight_overlay = array3 #Clear this out for the segmenter to use
+
+        self.parent().load_channel(2, array1, True) #Temp for debugging
+        self.parent().load_channel(3, array2, True) #Temp for debugging
+
+        self.parent().base_colors[2] = self.parent().color_dictionary['LIGHT_GREEN']
+        self.parent().base_colors[3] = self.parent().color_dictionary['LIGHT_RED']
+
+
+        # Set a reasonable default size
+        self.setMinimumWidth(400)
+        self.setMinimumHeight(400)
+
+        # Create zoom button and pan button
+        buttons_widget = QWidget()
+        buttons_layout = QHBoxLayout(buttons_widget)
+
+        # Create zoom button
+        self.brush_button = QPushButton("🖌️")
+        self.brush_button.setCheckable(True)
+        self.brush_button.setFixedSize(40, 40)
+        self.brush_button.clicked.connect(self.toggle_brush_mode)
+        form_layout.addWidget(self.brush_button)
+        self.brush_button.click()
+
+        self.fore_button = QPushButton("Foreground")
+        self.fore_button.setCheckable(True)
+        self.fore_button.setChecked(True)
+        self.fore_button.clicked.connect(self.toggle_foreground)
+        form_layout.addWidget(self.fore_button)
+
+        self.back_button = QPushButton("Background")
+        self.back_button.setCheckable(True)
+        self.back_button.setChecked(False)
+        self.back_button.clicked.connect(self.toggle_background)
+        form_layout.addWidget(self.back_button)
+
+        train_button = QPushButton("Train Model")
+        train_button.clicked.connect(self.train_model)
+        form_layout.addRow(train_button)
+
+        seg_button = QPushButton("Segment")
+        seg_button.clicked.connect(self.segment)
+        form_layout.addRow(seg_button)
+
+        self.trained = False
+
+
+        self.segmenter = segmenter.InteractiveSegmenter(active_data, use_gpu=True)
+
+
+
+
+    def toggle_foreground(self):
+
+        self.parent().foreground = self.fore_button.isChecked()
+
+        if self.parent().foreground:
+            self.back_button.setChecked(False)
+        else:
+            self.back_button.setChecked(True)
+
+    def switch_foreground(self):
+
+        self.fore_button.click()
+
+    def toggle_background(self):
+
+        self.parent().foreground = not self.back_button.isChecked()
+
+        if not self.parent().foreground:
+            self.fore_button.setChecked(False)
+        else:
+            self.fore_button.setChecked(True)
+
+
+
+    def toggle_brush_mode(self):
+        """Toggle brush mode on/off"""
+        self.parent().brush_mode = self.brush_button.isChecked()
+        if self.parent().brush_mode:
+            self.parent().pan_button.setChecked(False)
+            self.parent().zoom_button.setChecked(False)
+            self.parent().pan_mode = False
+            self.parent().zoom_mode = False
+            self.parent().update_brush_cursor()
+        else:
+            self.parent().zoom_button.click()
+
+    def silence_button(self):
+        self.brush_button.setChecked(False)
+
+    def toggle_brush_button(self):
+
+        self.brush_button.click()
+
+    def train_model(self):
+
+        self.segmenter.train_batch(self.parent().channel_data[2], self.parent().channel_data[3])
+        self.trained = True
+
+    def segment(self):
+
+        if not self.trained:
+            return
+        else:
+            foreground_coords, background_coords = self.segmenter.segment_volume()
+
+        # Clean up when done
+        self.segmenter.cleanup()
+
+        for z,y,x in foreground_coords:
+            self.parent().highlight_overlay[z,y,x] = True
+
+        self.parent().update_display()
+
+    def closeEvent(self, event):
+        if self.brush_button.isChecked():
+            self.silence_button()
+            self.toggle_brush_mode()
+        self.parent().brush_mode = False
+        self.parent().machine_window = None
+
 
 
 
@@ -5430,7 +5936,16 @@
             self.bounds = False
             self.parent().bounds = False
         elif accepted_mode == 0:
-            self.histo_list = self.parent().channel_data[self.parent().active_channel].flatten().tolist()
+            targ_shape = self.parent().channel_data[self.parent().active_channel].shape
+            if (targ_shape[0] + targ_shape[1] + targ_shape[2]) > 2500: #Take a simpler histogram on big arrays
+                temp_max = np.max(self.parent().channel_data[self.parent().active_channel])
+                temp_min = np.min(self.parent().channel_data[self.parent().active_channel])
+                temp_array = n3d.downsample(self.parent().channel_data[self.parent().active_channel], 5)
+                self.histo_list = temp_array.flatten().tolist()
+                self.histo_list.append(temp_min)
+                self.histo_list.append(temp_max)
+            else: #Otherwise just use full array data
+                self.histo_list = self.parent().channel_data[self.parent().active_channel].flatten().tolist()
             self.bounds = True
             self.parent().bounds = True
 
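The downsampled histogram above trades resolution for speed on large volumes; re-appending the global min and max matters because a subsample can miss the extreme intensities, which would shrink the apparent intensity range. A small sketch of the idea (plain strided slicing stands in for `n3d.downsample`):

    import numpy as np

    full = (np.random.rand(60, 256, 256) * 65535).astype(np.uint16)
    sample = full[::5, ::5, ::5]              # stand-in for n3d.downsample(full, 5)
    histo_list = sample.flatten().tolist()
    histo_list.append(int(full.min()))        # keep the true intensity extremes
    histo_list.append(int(full.max()))        # so the histogram range stays correct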
@@ -5676,8 +6191,7 @@
             self.close()
 
         except Exception as e:
-            import traceback
-            print(traceback.format_exc())
+
             QMessageBox.critical(
                 self,
                 "Error",
@@ -6104,9 +6618,17 @@
         self.directory = QLineEdit()
         self.directory.setPlaceholderText("Leave empty for None")
         layout.addRow("Output Directory:", self.directory)
+
+        active_shape = self.parent().channel_data[self.parent().active_channel].shape[0]
+
+        if active_shape == 1:
+            self.default = 0.2
+        else:
+            self.default = 0.05
+
 
         # Proportion (default 0.1)
-        self.proportion = QLineEdit("0.05")
+        self.proportion = QLineEdit(f"{self.default}")
         layout.addRow("Proportion:", self.proportion)
 
         # GPU checkbox (default True)
@@ -6142,9 +6664,9 @@
 
         # Get proportion (0.1 if empty or invalid)
         try:
-            proportion = float(self.proportion.text()) if self.proportion.text() else 0.05
+            proportion = float(self.proportion.text()) if self.proportion.text() else self.default
         except ValueError:
-            proportion = 0.05
+            proportion = self.default
 
         # Get GPU state
         gpu = self.gpu.isChecked()
@@ -6171,7 +6693,8 @@
         active_data = self.parent().channel_data[self.parent().active_channel]
         if active_data is None:
             raise ValueError("No active image selected")
-
+
+
         # Call watershed method with parameters
         result = n3d.watershed(
             active_data,
@@ -6194,6 +6717,8 @@
             self.accept()
 
         except Exception as e:
+            import traceback
+            print(traceback.format_exc())
             QMessageBox.critical(
                 self,
                 "Error",
--- /dev/null
+++ nettracer3d-0.4.4/src/nettracer3d/segmenter.py
@@ -0,0 +1,290 @@
+from sklearn.ensemble import RandomForestClassifier
+import numpy as np
+import cupy as cp
+import torch
+from concurrent.futures import ThreadPoolExecutor
+import threading
+import cupyx.scipy.ndimage as cpx
+
+
+class InteractiveSegmenter:
+    def __init__(self, image_3d, use_gpu=True):
+        self.image_3d = image_3d
+        self.patterns = []
+
+        self.use_gpu = use_gpu and cp.cuda.is_available()
+        if self.use_gpu:
+            print(f"Using GPU: {torch.cuda.get_device_name()}")
+            self.image_gpu = cp.asarray(image_3d)
+
+        self.model = RandomForestClassifier(
+            n_estimators=100,
+            n_jobs=-1,
+            max_depth=None
+        )
+        self.feature_cache = None
+        self.lock = threading.Lock()
+
+    def compute_feature_maps(self):
+        """Compute all feature maps using GPU acceleration"""
+        if not self.use_gpu:
+            return super().compute_feature_maps()
+
+        features = []
+        image = self.image_gpu
+        original_shape = self.image_3d.shape
+
+        # Gaussian smoothing at different scales
+        print("Obtaining gaussians")
+        for sigma in [0.5, 1.0, 2.0, 4.0]:
+            smooth = cp.asnumpy(self.gaussian_filter_gpu(image, sigma))
+            features.append(smooth)
+
+        print("Obtaining dif of gaussians")
+
+        # Difference of Gaussians
+        for (s1, s2) in [(1, 2), (2, 4)]:
+            g1 = self.gaussian_filter_gpu(image, s1)
+            g2 = self.gaussian_filter_gpu(image, s2)
+            dog = cp.asnumpy(g1 - g2)
+            features.append(dog)
+
+        # Convert image to PyTorch tensor for gradient operations
+        image_torch = torch.from_numpy(self.image_3d).cuda()
+        image_torch = image_torch.float().unsqueeze(0).unsqueeze(0)
+
+        # Calculate required padding
+        kernel_size = 3
+        padding = kernel_size // 2
+
+        # Create a single padded version with same padding
+        pad = torch.nn.functional.pad(image_torch, (padding, padding, padding, padding, padding, padding), mode='replicate')
+
+        print("Computing sobel kernels")
+
+        # Create sobel kernels
+        sobel_x = torch.tensor([-1, 0, 1], device='cuda').float().view(1,1,1,1,3)
+        sobel_y = torch.tensor([-1, 0, 1], device='cuda').float().view(1,1,1,3,1)
+        sobel_z = torch.tensor([-1, 0, 1], device='cuda').float().view(1,1,3,1,1)
+
+        # Compute gradients
+        print("Computing gradiants")
+
+        gx = torch.nn.functional.conv3d(pad, sobel_x, padding=0)[:,:,:original_shape[0],:original_shape[1],:original_shape[2]]
+        gy = torch.nn.functional.conv3d(pad, sobel_y, padding=0)[:,:,:original_shape[0],:original_shape[1],:original_shape[2]]
+        gz = torch.nn.functional.conv3d(pad, sobel_z, padding=0)[:,:,:original_shape[0],:original_shape[1],:original_shape[2]]
+
+        # Compute gradient magnitude
+        print("Computing gradiant mags")
+
+        gradient_magnitude = torch.sqrt(gx**2 + gy**2 + gz**2)
+        gradient_feature = gradient_magnitude.cpu().numpy().squeeze()
+
+        features.append(gradient_feature)
+
+        # Verify shapes
+        for i, feat in enumerate(features):
+            if feat.shape != original_shape:
+                raise ValueError(f"Feature {i} has shape {feat.shape}, expected {original_shape}")
+
+        return np.stack(features, axis=-1)
+
+    def gaussian_filter_gpu(self, image, sigma):
+        """GPU-accelerated Gaussian filter"""
+        # Create Gaussian kernel
+        result = cpx.gaussian_filter(image, sigma=sigma)
+
+        return result
+
+
+    def train(self):
+        """Train random forest on accumulated patterns"""
+        if len(self.patterns) < 2:
+            return
+
+        X = []
+        y = []
+        for pattern in self.patterns:
+            X.extend(pattern['features'])
+            y.extend([pattern['is_foreground']] * len(pattern['features']))
+
+        X = np.array(X)
+        y = np.array(y)
+        self.model.fit(X, y)
+        self.patterns = []
+
+    def process_chunk(self, chunk_coords):
+        """Process a chunk of coordinates"""
+        features = [self.feature_cache[z, y, x] for z, y, x in chunk_coords]
+        predictions = self.model.predict(features)
+
+        foreground = set()
+        background = set()
+        for coord, pred in zip(chunk_coords, predictions):
+            if pred:
+                foreground.add(coord)
+            else:
+                background.add(coord)
+
+        return foreground, background
+
+    def segment_volume(self, chunk_size=32):
+        """Segment volume using parallel processing of chunks"""
+        if self.feature_cache is None:
+            with self.lock:
+                if self.feature_cache is None:
+                    self.feature_cache = self.compute_feature_maps()
+
+        # Create chunks of coordinates
+        chunks = []
+        for z in range(0, self.image_3d.shape[0], chunk_size):
+            for y in range(0, self.image_3d.shape[1], chunk_size):
+                for x in range(0, self.image_3d.shape[2], chunk_size):
+                    chunk_coords = [
+                        (zz, yy, xx)
+                        for zz in range(z, min(z + chunk_size, self.image_3d.shape[0]))
+                        for yy in range(y, min(y + chunk_size, self.image_3d.shape[1]))
+                        for xx in range(x, min(x + chunk_size, self.image_3d.shape[2]))
+                    ]
+                    chunks.append(chunk_coords)
+
+        foreground_coords = set()
+        background_coords = set()
+
+        # Process chunks in parallel
+        with ThreadPoolExecutor() as executor:
+            futures = [executor.submit(self.process_chunk, chunk) for chunk in chunks]
+
+            for i, future in enumerate(futures):
+                fore, back = future.result()
+                foreground_coords.update(fore)
+                background_coords.update(back)
+                if i % 10 == 0:
+                    print(f"Processed {i}/{len(chunks)} chunks")
+
+        return foreground_coords, background_coords
+
+    def cleanup(self):
+        """Clean up GPU memory"""
+        if self.use_gpu:
+            cp.get_default_memory_pool().free_all_blocks()
+            torch.cuda.empty_cache()
+
+    def train_batch(self, foreground_array, background_array):
+        """Train directly on foreground and background arrays"""
+        if self.feature_cache is None:
+            with self.lock:
+                if self.feature_cache is None:
+                    self.feature_cache = self.compute_feature_maps()
+
+        # Get foreground coordinates and features
+        z_fore, y_fore, x_fore = np.where(foreground_array > 0)
+        foreground_features = self.feature_cache[z_fore, y_fore, x_fore]
+
+        # Get background coordinates and features
+        z_back, y_back, x_back = np.where(background_array > 0)
+        background_features = self.feature_cache[z_back, y_back, x_back]
+
+        # Combine features and labels
+        X = np.vstack([foreground_features, background_features])
+        y = np.hstack([np.ones(len(z_fore)), np.zeros(len(z_back))])
+
+        # Train the model
+        self.model.fit(X, y)
+
+        print("Done")
+
+
+
+
+
+
+
+
+
+    def segment_volume_subprocess(self, chunk_size=32, current_z=None, current_x=None, current_y=None):
+        """
+        Segment volume prioritizing chunks near user location.
+        Returns chunks as they're processed.
+        """
+        if self.feature_cache is None:
+            with self.lock:
+                if self.feature_cache is None:
+                    self.feature_cache = self.compute_feature_maps()
+
+        # Create chunks with position information
+        chunks_info = []
+        for z in range(0, self.image_3d.shape[0], chunk_size):
+            for y in range(0, self.image_3d.shape[1], chunk_size):
+                for x in range(0, self.image_3d.shape[2], chunk_size):
+                    chunk_coords = [
+                        (zz, yy, xx)
+                        for zz in range(z, min(z + chunk_size, self.image_3d.shape[0]))
+                        for yy in range(y, min(y + chunk_size, self.image_3d.shape[1]))
+                        for xx in range(x, min(x + chunk_size, self.image_3d.shape[2]))
+                    ]
+
+                    # Store chunk with its corner position
+                    chunks_info.append({
+                        'coords': chunk_coords,
+                        'corner': (z, y, x),
+                        'processed': False
+                    })
+
+        def get_chunk_priority(chunk):
+            """Calculate priority based on distance from user position"""
+            z, y, x = chunk['corner']
+            priority = 0
+
+            # Priority based on Z distance (always used)
+            if current_z is not None:
+                priority += abs(z - current_z)
+
+            # Add X/Y distance if provided
+            if current_x is not None and current_y is not None:
+                xy_distance = ((x - current_x) ** 2 + (y - current_y) ** 2) ** 0.5
+                priority += xy_distance
+
+            return priority
+
+        with ThreadPoolExecutor() as executor:
+            futures = {} # Track active futures
+
+            while True:
+                # Sort unprocessed chunks by priority
+                unprocessed_chunks = [c for c in chunks_info if not c['processed']]
+                if not unprocessed_chunks:
+                    break
+
+                # Sort by distance from current position
+                unprocessed_chunks.sort(key=get_chunk_priority)
+
+                # Submit new chunks to replace completed ones
+                while len(futures) < executor._max_workers and unprocessed_chunks:
+                    chunk = unprocessed_chunks.pop(0)
+                    future = executor.submit(self.process_chunk, chunk['coords'])
+                    futures[future] = chunk
+                    chunk['processed'] = True
+
+                # Check completed futures
+                done, _ = concurrent.futures.wait(
+                    futures.keys(),
+                    timeout=0.1,
+                    return_when=concurrent.futures.FIRST_COMPLETED
+                )
+
+                # Process completed chunks
+                for future in done:
+                    chunk = futures[future]
+                    fore, back = future.result()
+
+                    # Yield chunk results with position information
+                    yield {
+                        'foreground': fore,
+                        'background': back,
+                        'corner': chunk['corner'],
+                        'size': chunk_size
+                    }
+
+                    del futures[future]
+
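Taken together with the GUI changes, the workflow is: brush strokes land in overlay channels 2 (foreground) and 3 (background), `train_batch` looks up the cached per-voxel feature vectors at the scribbled coordinates and fits the random forest, and `segment_volume` then classifies every voxel in 32-cubed chunks, returning coordinate sets that the GUI writes into the highlight overlay. A hedged end-to-end sketch (the module imports cupy and torch unconditionally, so a CUDA-capable environment is assumed; array sizes and stroke positions are illustrative):

    import numpy as np
    from nettracer3d import segmenter

    volume = np.random.rand(32, 128, 128).astype(np.float32)  # raw image volume
    fore = np.zeros_like(volume)                               # foreground scribbles
    back = np.zeros_like(volume)                               # background scribbles
    fore[16, 60:70, 60:70] = 255                               # strokes painted in the GUI
    back[16, 0:10, 0:10] = 255

    seg = segmenter.InteractiveSegmenter(volume, use_gpu=True)
    seg.train_batch(fore, back)                 # fit the random forest on scribbled voxels
    fg, bg = seg.segment_volume(chunk_size=32)  # sets of (z, y, x) tuples

    mask = np.zeros(volume.shape, dtype=bool)   # mirror of the GUI's highlight_overlay
    for z, y, x in fg:
        mask[z, y, x] = True
    seg.cleanup()                               # release GPU memory pools

One caveat visible in the file: `segment_volume_subprocess` calls `concurrent.futures.wait`, which would additionally need `import concurrent.futures` beside the existing `ThreadPoolExecutor` import.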
--- nettracer3d-0.4.3/src/nettracer3d/smart_dilate.py
+++ nettracer3d-0.4.4/src/nettracer3d/smart_dilate.py
@@ -383,42 +383,83 @@ def smart_label(binary_array, label_array, directory = None, GPU = True, predown
     return dilated_nodes_with_labels
 
 def compute_distance_transform_GPU(nodes):
+    is_pseudo_3d = nodes.shape[0] == 1
+    if is_pseudo_3d:
+        nodes = np.squeeze(nodes) # Convert to 2D for processing
+
     # Convert numpy array to CuPy array
     nodes_cp = cp.asarray(nodes)
 
     # Compute the distance transform on the GPU
-    distance, nearest_label_indices = cpx.distance_transform_edt(nodes_cp, return_indices=True)
+    _, nearest_label_indices = cpx.distance_transform_edt(nodes_cp, return_indices=True)
 
     # Convert results back to numpy arrays
     nearest_label_indices_np = cp.asnumpy(nearest_label_indices)
 
+    if is_pseudo_3d:
+        # For 2D input, we get (2, H, W) but need (3, 1, H, W)
+        H, W = nearest_label_indices_np[0].shape
+        indices_4d = np.zeros((3, 1, H, W), dtype=nearest_label_indices_np.dtype)
+        indices_4d[1:, 0] = nearest_label_indices_np # Copy Y and X coordinates
+        # indices_4d[0] stays 0 for all Z coordinates
+        nearest_label_indices_np = indices_4d
+
+
+
+
     return nearest_label_indices_np
 
 
 def compute_distance_transform(nodes):
+    is_pseudo_3d = nodes.shape[0] == 1
+    if is_pseudo_3d:
+        nodes = np.squeeze(nodes) # Convert to 2D for processing
+
     distance, nearest_label_indices = distance_transform_edt(nodes, return_indices=True)
+
+    if is_pseudo_3d:
+        # For 2D input, we get (2, H, W) but need (3, 1, H, W)
+        H, W = nearest_label_indices_np[0].shape
+        indices_4d = np.zeros((3, 1, H, W), dtype=nearest_label_indices_np.dtype)
+        indices_4d[1:, 0] = nearest_label_indices_np # Copy Y and X coordinates
+        # indices_4d[0] stays 0 for all Z coordinates
+        nearest_label_indices_np = indices_4d
+
     return nearest_label_indices
 
 
 
 def compute_distance_transform_distance_GPU(nodes):
 
+    is_pseudo_3d = nodes.shape[0] == 1
+    if is_pseudo_3d:
+        nodes = np.squeeze(nodes) # Convert to 2D for processing
+
     # Convert numpy array to CuPy array
     nodes_cp = cp.asarray(nodes)
 
     # Compute the distance transform on the GPU
-    distance, nearest_label_indices = cpx.distance_transform_edt(nodes_cp, return_indices=True)
+    distance, _ = cpx.distance_transform_edt(nodes_cp, return_indices=True)
 
     # Convert results back to numpy arrays
     distance = cp.asnumpy(distance)
+
+    if is_pseudo_3d:
+        np.expand_dims(distance, axis = 0)
 
     return distance
 
 
 def compute_distance_transform_distance(nodes):
 
+    is_pseudo_3d = nodes.shape[0] == 1
+    if is_pseudo_3d:
+        nodes = np.squeeze(nodes) # Convert to 2D for processing
+
     # Fallback to CPU if there's an issue with GPU computation
-    distance, nearest_label_indices = distance_transform_edt(nodes, return_indices=True)
+    distance, _ = distance_transform_edt(nodes, return_indices=True)
+    if is_pseudo_3d:
+        np.expand_dims(distance, axis = 0)
    return distance
 
 
--- nettracer3d-0.4.3/PKG-INFO
+++ nettracer3d-0.4.4/src/nettracer3d.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: nettracer3d
-Version: 0.4.3
+Version: 0.4.4
 Summary: Scripts for intializing and analyzing networks from segmentations of three dimensional images.
 Author-email: Liam McLaughlin <boom2449@gmail.com>
 Project-URL: User_Manual, https://drive.google.com/drive/folders/1fTkz3n4LN9_VxKRKC8lVQSlrz_wq0bVn?usp=drive_link
@@ -32,7 +32,7 @@ Requires-Dist: cupy-cuda12x; extra == "cuda12"
 Provides-Extra: cupy
 Requires-Dist: cupy; extra == "cupy"
 
-NetTracer3D is a python package developed for both 2D and 3D analysis of microscopic images in the .tif file format. It supports generation of 3D networks showing the relationships between objects (or nodes) in three dimensional space, either based on their own proximity or connectivity via connecting objects such as nerves or blood vessels. In addition to these functionalities are several advanced 3D data processing algorithms, such as labeling of branched structures or abstraction of branched structures into networks. Note that nettracer3d uses segmented data, which can be segmented from other softwares such as ImageJ and imported into NetTracer3D, although it does offer its own segmentation via intensity or volumetric thresholding. NetTracer3D currently has a fully functional GUI. To use the GUI, after installing the nettracer3d package via pip, enter the command 'nettracer3d' in your command prompt:
+NetTracer3D is a python package developed for both 2D and 3D analysis of microscopic images in the .tif file format. It supports generation of 3D networks showing the relationships between objects (or nodes) in three dimensional space, either based on their own proximity or connectivity via connecting objects such as nerves or blood vessels. In addition to these functionalities are several advanced 3D data processing algorithms, such as labeling of branched structures or abstraction of branched structures into networks. Note that nettracer3d uses segmented data, which can be segmented from other softwares such as ImageJ and imported into NetTracer3D, although it does offer its own segmentation via intensity and volumetric thresholding, or random forest machine learning segmentation. NetTracer3D currently has a fully functional GUI. To use the GUI, after installing the nettracer3d package via pip, enter the command 'nettracer3d' in your command prompt:
 
 
 This gui is built from the PyQt6 package and therefore may not function on dockers or virtual envs that are unable to support PyQt6 displays. More advanced documentation (especially for the GUI) is coming down the line, but for now please see: https://drive.google.com/drive/folders/1fTkz3n4LN9_VxKRKC8lVQSlrz_wq0bVn?usp=drive_link
--- nettracer3d-0.4.3/src/nettracer3d.egg-info/SOURCES.txt
+++ nettracer3d-0.4.4/src/nettracer3d.egg-info/SOURCES.txt
@@ -13,6 +13,7 @@ src/nettracer3d/network_draw.py
 src/nettracer3d/node_draw.py
 src/nettracer3d/proximity.py
 src/nettracer3d/run.py
+src/nettracer3d/segmenter.py
 src/nettracer3d/simple_network.py
 src/nettracer3d/smart_dilate.py
 src/nettracer3d.egg-info/PKG-INFO