spacr 1.0.6__py3-none-any.whl → 1.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/gui_elements.py CHANGED
@@ -1,31 +1,12 @@
1
- import os, threading, time, sqlite3, webbrowser, pyautogui, random, cv2
1
+ import os, webbrowser, pyautogui, random, cv2
2
+ from tkinter import ttk, scrolledtext
2
3
  import tkinter as tk
3
4
  from tkinter import ttk
4
5
  import tkinter.font as tkFont
5
6
  from tkinter import filedialog
6
7
  from tkinter import font
7
- from queue import Queue
8
- from tkinter import Label, Frame, Button
9
8
  import numpy as np
10
- import pandas as pd
11
- from PIL import Image, ImageOps, ImageTk, ImageDraw, ImageFont, ImageEnhance
12
- from concurrent.futures import ThreadPoolExecutor
13
- from IPython.display import display, HTML
14
- import imageio.v2 as imageio
15
- from collections import deque
16
- from skimage.filters import threshold_otsu
17
- from skimage.exposure import rescale_intensity
18
- from skimage.draw import polygon, line
19
- from skimage.transform import resize
20
- from skimage.morphology import dilation, disk
21
- from skimage.segmentation import find_boundaries
22
- from skimage.util import img_as_ubyte
23
- from scipy.ndimage import binary_fill_holes, label, gaussian_filter
24
- from tkinter import ttk, scrolledtext
25
- from sklearn.model_selection import train_test_split
26
- from xgboost import XGBClassifier
27
- from sklearn.metrics import classification_report, confusion_matrix
28
-
9
+ from PIL import Image, ImageTk, ImageDraw, ImageFont, ImageEnhance
29
10
 
30
11
  fig = None
31
12
 
@@ -50,18 +31,12 @@ def create_menu_bar(root):
50
31
  gui_apps = {
51
32
  "Mask": lambda: initiate_root(root, settings_type='mask'),
52
33
  "Measure": lambda: initiate_root(root, settings_type='measure'),
53
- "Annotate (Beta)": lambda: initiate_root(root, settings_type='annotate'),
54
- #"Make Masks": lambda: initiate_root(root, settings_type='make_masks'),
55
34
  "Classify": lambda: initiate_root(root, settings_type='classify'),
56
- #"Umap": lambda: initiate_root(root, settings_type='umap'),
57
- #"Train Cellpose": lambda: initiate_root(root, settings_type='train_cellpose'),
58
35
  "ML Analyze": lambda: initiate_root(root, settings_type='ml_analyze'),
59
- #"Cellpose Masks": lambda: initiate_root(root, settings_type='cellpose_masks'),
60
- #"Cellpose All": lambda: initiate_root(root, settings_type='cellpose_all'),
61
36
  "Map Barcodes": lambda: initiate_root(root, settings_type='map_barcodes'),
62
37
  "Regression": lambda: initiate_root(root, settings_type='regression'),
63
38
  "Activation": lambda: initiate_root(root, settings_type='activation'),
64
- "Recruitment (Beta)": lambda: initiate_root(root, settings_type='recruitment')
39
+ "Recruitment": lambda: initiate_root(root, settings_type='recruitment')
65
40
  }
66
41
 
67
42
  # Create the menu bar
@@ -1344,1655 +1319,6 @@ class spacrToolTip:
1344
1319
  self.tooltip_window.destroy()
1345
1320
  self.tooltip_window = None
1346
1321
 
1347
- class ModifyMaskApp:
1348
- def __init__(self, root, folder_path, scale_factor):
1349
- self.root = root
1350
- self.folder_path = folder_path
1351
- self.scale_factor = scale_factor
1352
- self.image_filenames = sorted([f for f in os.listdir(folder_path) if f.endswith(('.png', '.jpg', '.jpeg', '.tif', '.tiff'))])
1353
- self.masks_folder = os.path.join(folder_path, 'masks')
1354
- self.current_image_index = 0
1355
- self.initialize_flags()
1356
- self.canvas_width = self.root.winfo_screenheight() -100
1357
- self.canvas_height = self.root.winfo_screenheight() -100
1358
- self.root.configure(bg='black')
1359
- self.setup_navigation_toolbar()
1360
- self.setup_mode_toolbar()
1361
- self.setup_function_toolbar()
1362
- self.setup_zoom_toolbar()
1363
- self.setup_canvas()
1364
- self.load_first_image()
1365
-
1366
- ####################################################################################################
1367
- # Helper functions#
1368
- ####################################################################################################
1369
-
1370
- def update_display(self):
1371
- if self.zoom_active:
1372
- self.display_zoomed_image()
1373
- else:
1374
- self.display_image()
1375
-
1376
- def update_original_mask_from_zoom(self):
1377
- y0, y1, x0, x1 = self.zoom_y0, self.zoom_y1, self.zoom_x0, self.zoom_x1
1378
- zoomed_mask_resized = resize(self.zoom_mask, (y1 - y0, x1 - x0), order=0, preserve_range=True).astype(np.uint8)
1379
- self.mask[y0:y1, x0:x1] = zoomed_mask_resized
1380
-
1381
- def update_original_mask(self, zoomed_mask, x0, x1, y0, y1):
1382
- actual_mask_region = self.mask[y0:y1, x0:x1]
1383
- target_shape = actual_mask_region.shape
1384
- resized_mask = resize(zoomed_mask, target_shape, order=0, preserve_range=True).astype(np.uint8)
1385
- if resized_mask.shape != actual_mask_region.shape:
1386
- raise ValueError(f"Shape mismatch: resized_mask {resized_mask.shape}, actual_mask_region {actual_mask_region.shape}")
1387
- self.mask[y0:y1, x0:x1] = np.maximum(actual_mask_region, resized_mask)
1388
- self.mask = self.mask.copy()
1389
- self.mask[y0:y1, x0:x1] = np.maximum(self.mask[y0:y1, x0:x1], resized_mask)
1390
- self.mask = self.mask.copy()
1391
-
1392
- def get_scaling_factors(self, img_width, img_height, canvas_width, canvas_height):
1393
- x_scale = img_width / canvas_width
1394
- y_scale = img_height / canvas_height
1395
- return x_scale, y_scale
1396
-
1397
- def canvas_to_image(self, x_canvas, y_canvas):
1398
- x_scale, y_scale = self.get_scaling_factors(
1399
- self.image.shape[1], self.image.shape[0],
1400
- self.canvas_width, self.canvas_height
1401
- )
1402
- x_image = int(x_canvas * x_scale)
1403
- y_image = int(y_canvas * y_scale)
1404
- return x_image, y_image
1405
-
1406
- def apply_zoom_on_enter(self, event):
1407
- if self.zoom_active and self.zoom_rectangle_start is not None:
1408
- self.set_zoom_rectangle_end(event)
1409
-
1410
- def normalize_image(self, image, lower_quantile, upper_quantile):
1411
- lower_bound = np.percentile(image, lower_quantile)
1412
- upper_bound = np.percentile(image, upper_quantile)
1413
- normalized = np.clip(image, lower_bound, upper_bound)
1414
- normalized = (normalized - lower_bound) / (upper_bound - lower_bound)
1415
- max_value = np.iinfo(image.dtype).max
1416
- normalized = (normalized * max_value).astype(image.dtype)
1417
- return normalized
1418
-
1419
- def resize_arrays(self, img, mask):
1420
- original_dtype = img.dtype
1421
- scaled_height = int(img.shape[0] * self.scale_factor)
1422
- scaled_width = int(img.shape[1] * self.scale_factor)
1423
- scaled_img = resize(img, (scaled_height, scaled_width), anti_aliasing=True, preserve_range=True)
1424
- scaled_mask = resize(mask, (scaled_height, scaled_width), order=0, anti_aliasing=False, preserve_range=True)
1425
- stretched_img = resize(scaled_img, (self.canvas_height, self.canvas_width), anti_aliasing=True, preserve_range=True)
1426
- stretched_mask = resize(scaled_mask, (self.canvas_height, self.canvas_width), order=0, anti_aliasing=False, preserve_range=True)
1427
- return stretched_img.astype(original_dtype), stretched_mask.astype(original_dtype)
1428
-
1429
- ####################################################################################################
1430
- #Initiate canvas elements#
1431
- ####################################################################################################
1432
-
1433
- def load_first_image(self):
1434
- self.image, self.mask = self.load_image_and_mask(self.current_image_index)
1435
- self.original_size = self.image.shape
1436
- self.image, self.mask = self.resize_arrays(self.image, self.mask)
1437
- self.display_image()
1438
-
1439
- def setup_canvas(self):
1440
- self.canvas = tk.Canvas(self.root, width=self.canvas_width, height=self.canvas_height, bg='black')
1441
- self.canvas.pack()
1442
- self.canvas.bind("<Motion>", self.update_mouse_info)
1443
-
1444
- def initialize_flags(self):
1445
- self.zoom_rectangle_start = None
1446
- self.zoom_rectangle_end = None
1447
- self.zoom_rectangle_id = None
1448
- self.zoom_x0 = None
1449
- self.zoom_y0 = None
1450
- self.zoom_x1 = None
1451
- self.zoom_y1 = None
1452
- self.zoom_mask = None
1453
- self.zoom_image = None
1454
- self.zoom_image_orig = None
1455
- self.zoom_scale = 1
1456
- self.drawing = False
1457
- self.zoom_active = False
1458
- self.magic_wand_active = False
1459
- self.brush_active = False
1460
- self.dividing_line_active = False
1461
- self.dividing_line_coords = []
1462
- self.current_dividing_line = None
1463
- self.lower_quantile = tk.StringVar(value="1.0")
1464
- self.upper_quantile = tk.StringVar(value="99.9")
1465
- self.magic_wand_tolerance = tk.StringVar(value="1000")
1466
-
1467
- def update_mouse_info(self, event):
1468
- x, y = event.x, event.y
1469
- intensity = "N/A"
1470
- mask_value = "N/A"
1471
- pixel_count = "N/A"
1472
- if self.zoom_active:
1473
- if 0 <= x < self.canvas_width and 0 <= y < self.canvas_height:
1474
- intensity = self.zoom_image_orig[y, x] if self.zoom_image_orig is not None else "N/A"
1475
- mask_value = self.zoom_mask[y, x] if self.zoom_mask is not None else "N/A"
1476
- else:
1477
- if 0 <= x < self.image.shape[1] and 0 <= y < self.image.shape[0]:
1478
- intensity = self.image[y, x]
1479
- mask_value = self.mask[y, x]
1480
- if mask_value != "N/A" and mask_value != 0:
1481
- pixel_count = np.sum(self.mask == mask_value)
1482
- self.intensity_label.config(text=f"Intensity: {intensity}")
1483
- self.mask_value_label.config(text=f"Mask: {mask_value}, Area: {pixel_count}")
1484
- self.mask_value_label.config(text=f"Mask: {mask_value}")
1485
- if mask_value != "N/A" and mask_value != 0:
1486
- self.pixel_count_label.config(text=f"Area: {pixel_count}")
1487
- else:
1488
- self.pixel_count_label.config(text="Area: N/A")
1489
-
1490
- def setup_navigation_toolbar(self):
1491
- navigation_toolbar = tk.Frame(self.root, bg='black')
1492
- navigation_toolbar.pack(side='top', fill='x')
1493
- prev_btn = tk.Button(navigation_toolbar, text="Previous", command=self.previous_image, bg='black', fg='white')
1494
- prev_btn.pack(side='left')
1495
- next_btn = tk.Button(navigation_toolbar, text="Next", command=self.next_image, bg='black', fg='white')
1496
- next_btn.pack(side='left')
1497
- save_btn = tk.Button(navigation_toolbar, text="Save", command=self.save_mask, bg='black', fg='white')
1498
- save_btn.pack(side='left')
1499
- self.intensity_label = tk.Label(navigation_toolbar, text="Image: N/A", bg='black', fg='white')
1500
- self.intensity_label.pack(side='right')
1501
- self.mask_value_label = tk.Label(navigation_toolbar, text="Mask: N/A", bg='black', fg='white')
1502
- self.mask_value_label.pack(side='right')
1503
- self.pixel_count_label = tk.Label(navigation_toolbar, text="Area: N/A", bg='black', fg='white')
1504
- self.pixel_count_label.pack(side='right')
1505
-
1506
- def setup_mode_toolbar(self):
1507
- self.mode_toolbar = tk.Frame(self.root, bg='black')
1508
- self.mode_toolbar.pack(side='top', fill='x')
1509
- self.draw_btn = tk.Button(self.mode_toolbar, text="Draw", command=self.toggle_draw_mode, bg='black', fg='white')
1510
- self.draw_btn.pack(side='left')
1511
- self.magic_wand_btn = tk.Button(self.mode_toolbar, text="Magic Wand", command=self.toggle_magic_wand_mode, bg='black', fg='white')
1512
- self.magic_wand_btn.pack(side='left')
1513
- tk.Label(self.mode_toolbar, text="Tolerance:", bg='black', fg='white').pack(side='left')
1514
- self.tolerance_entry = tk.Entry(self.mode_toolbar, textvariable=self.magic_wand_tolerance, bg='black', fg='white')
1515
- self.tolerance_entry.pack(side='left')
1516
- tk.Label(self.mode_toolbar, text="Max Pixels:", bg='black', fg='white').pack(side='left')
1517
- self.max_pixels_entry = tk.Entry(self.mode_toolbar, bg='black', fg='white')
1518
- self.max_pixels_entry.insert(0, "1000")
1519
- self.max_pixels_entry.pack(side='left')
1520
- self.erase_btn = tk.Button(self.mode_toolbar, text="Erase", command=self.toggle_erase_mode, bg='black', fg='white')
1521
- self.erase_btn.pack(side='left')
1522
- self.brush_btn = tk.Button(self.mode_toolbar, text="Brush", command=self.toggle_brush_mode, bg='black', fg='white')
1523
- self.brush_btn.pack(side='left')
1524
- self.brush_size_entry = tk.Entry(self.mode_toolbar, bg='black', fg='white')
1525
- self.brush_size_entry.insert(0, "10")
1526
- self.brush_size_entry.pack(side='left')
1527
- tk.Label(self.mode_toolbar, text="Brush Size:", bg='black', fg='white').pack(side='left')
1528
- self.dividing_line_btn = tk.Button(self.mode_toolbar, text="Dividing Line", command=self.toggle_dividing_line_mode, bg='black', fg='white')
1529
- self.dividing_line_btn.pack(side='left')
1530
-
1531
- def setup_function_toolbar(self):
1532
- self.function_toolbar = tk.Frame(self.root, bg='black')
1533
- self.function_toolbar.pack(side='top', fill='x')
1534
- self.fill_btn = tk.Button(self.function_toolbar, text="Fill", command=self.fill_objects, bg='black', fg='white')
1535
- self.fill_btn.pack(side='left')
1536
- self.relabel_btn = tk.Button(self.function_toolbar, text="Relabel", command=self.relabel_objects, bg='black', fg='white')
1537
- self.relabel_btn.pack(side='left')
1538
- self.clear_btn = tk.Button(self.function_toolbar, text="Clear", command=self.clear_objects, bg='black', fg='white')
1539
- self.clear_btn.pack(side='left')
1540
- self.invert_btn = tk.Button(self.function_toolbar, text="Invert", command=self.invert_mask, bg='black', fg='white')
1541
- self.invert_btn.pack(side='left')
1542
- remove_small_btn = tk.Button(self.function_toolbar, text="Remove Small", command=self.remove_small_objects, bg='black', fg='white')
1543
- remove_small_btn.pack(side='left')
1544
- tk.Label(self.function_toolbar, text="Min Area:", bg='black', fg='white').pack(side='left')
1545
- self.min_area_entry = tk.Entry(self.function_toolbar, bg='black', fg='white')
1546
- self.min_area_entry.insert(0, "100") # Default minimum area
1547
- self.min_area_entry.pack(side='left')
1548
-
1549
- def setup_zoom_toolbar(self):
1550
- self.zoom_toolbar = tk.Frame(self.root, bg='black')
1551
- self.zoom_toolbar.pack(side='top', fill='x')
1552
- self.zoom_btn = tk.Button(self.zoom_toolbar, text="Zoom", command=self.toggle_zoom_mode, bg='black', fg='white')
1553
- self.zoom_btn.pack(side='left')
1554
- self.normalize_btn = tk.Button(self.zoom_toolbar, text="Apply Normalization", command=self.apply_normalization, bg='black', fg='white')
1555
- self.normalize_btn.pack(side='left')
1556
- tk.Label(self.zoom_toolbar, text="Lower Percentile:", bg='black', fg='white').pack(side='left')
1557
- self.lower_entry = tk.Entry(self.zoom_toolbar, textvariable=self.lower_quantile, bg='black', fg='white')
1558
- self.lower_entry.pack(side='left')
1559
-
1560
- tk.Label(self.zoom_toolbar, text="Upper Percentile:", bg='black', fg='white').pack(side='left')
1561
- self.upper_entry = tk.Entry(self.zoom_toolbar, textvariable=self.upper_quantile, bg='black', fg='white')
1562
- self.upper_entry.pack(side='left')
1563
-
1564
- def load_image_and_mask(self, index):
1565
- # Load the image
1566
- image_path = os.path.join(self.folder_path, self.image_filenames[index])
1567
- image = imageio.imread(image_path)
1568
- print(f"Original Image shape: {image.shape}, dtype: {image.dtype}")
1569
-
1570
- # Handle multi-channel or transparency issues
1571
- if image.ndim == 3:
1572
- if image.shape[2] == 4: # If the image has an alpha channel (RGBA)
1573
- image = image[..., :3] # Remove the alpha channel
1574
-
1575
- # Convert RGB to grayscale using weighted average
1576
- image = np.dot(image[..., :3], [0.2989, 0.5870, 0.1140]).astype(np.uint8)
1577
- print(f"Converted to grayscale: {image.shape}")
1578
-
1579
- # Ensure the shape is (height, width) without extra channel
1580
- if image.ndim == 3 and image.shape[2] == 1:
1581
- image = np.squeeze(image, axis=-1)
1582
-
1583
- if image.dtype != np.uint16:
1584
- # Scale the image to fit the 16-bit range (0–65535)
1585
- image = (image / image.max() * 65535).astype(np.uint16)
1586
- # eventually remove this images should not have to be 16 bit look into downstream function (non 16bit images are jsut black)
1587
-
1588
- # Load the corresponding mask
1589
- mask_path = os.path.join(self.masks_folder, self.image_filenames[index])
1590
- if os.path.exists(mask_path):
1591
- print(f'Loading mask: {mask_path} for image: {image_path}')
1592
- mask = imageio.imread(mask_path)
1593
-
1594
- # Ensure mask is uint8
1595
- if mask.dtype != np.uint8:
1596
- mask = (mask / mask.max() * 255).astype(np.uint8)
1597
- else:
1598
- # Create a new mask with the same size as the image
1599
- mask = np.zeros(image.shape[:2], dtype=np.uint8)
1600
- print(f'Loaded new mask for image: {image_path}')
1601
-
1602
- return image, mask
1603
-
1604
- ####################################################################################################
1605
- # Image Display functions#
1606
- ####################################################################################################
1607
- def display_image(self):
1608
- if self.zoom_rectangle_id is not None:
1609
- self.canvas.delete(self.zoom_rectangle_id)
1610
- self.zoom_rectangle_id = None
1611
- lower_quantile = float(self.lower_quantile.get()) if self.lower_quantile.get() else 1.0
1612
- upper_quantile = float(self.upper_quantile.get()) if self.upper_quantile.get() else 99.9
1613
- normalized = self.normalize_image(self.image, lower_quantile, upper_quantile)
1614
- combined = self.overlay_mask_on_image(normalized, self.mask)
1615
- self.tk_image = ImageTk.PhotoImage(image=Image.fromarray(combined))
1616
- self.canvas.create_image(0, 0, anchor='nw', image=self.tk_image)
1617
-
1618
- def display_zoomed_image(self):
1619
- if self.zoom_rectangle_start and self.zoom_rectangle_end:
1620
- # Convert canvas coordinates to image coordinates
1621
- x0, y0 = self.canvas_to_image(*self.zoom_rectangle_start)
1622
- x1, y1 = self.canvas_to_image(*self.zoom_rectangle_end)
1623
- x0, x1 = min(x0, x1), max(x0, x1)
1624
- y0, y1 = min(y0, y1), max(y0, y1)
1625
- self.zoom_x0 = x0
1626
- self.zoom_y0 = y0
1627
- self.zoom_x1 = x1
1628
- self.zoom_y1 = y1
1629
- # Normalize the entire image
1630
- lower_quantile = float(self.lower_quantile.get()) if self.lower_quantile.get() else 1.0
1631
- upper_quantile = float(self.upper_quantile.get()) if self.upper_quantile.get() else 99.9
1632
- normalized_image = self.normalize_image(self.image, lower_quantile, upper_quantile)
1633
- # Extract the zoomed portion of the normalized image and mask
1634
- self.zoom_image = normalized_image[y0:y1, x0:x1]
1635
- self.zoom_image_orig = self.image[y0:y1, x0:x1]
1636
- self.zoom_mask = self.mask[y0:y1, x0:x1]
1637
- original_mask_area = self.mask.shape[0] * self.mask.shape[1]
1638
- zoom_mask_area = self.zoom_mask.shape[0] * self.zoom_mask.shape[1]
1639
- if original_mask_area > 0:
1640
- self.zoom_scale = original_mask_area/zoom_mask_area
1641
- # Resize the zoomed image and mask to fit the canvas
1642
- canvas_height = self.canvas.winfo_height()
1643
- canvas_width = self.canvas.winfo_width()
1644
-
1645
- if self.zoom_image.size > 0 and canvas_height > 0 and canvas_width > 0:
1646
- self.zoom_image = resize(self.zoom_image, (canvas_height, canvas_width), preserve_range=True).astype(self.zoom_image.dtype)
1647
- self.zoom_image_orig = resize(self.zoom_image_orig, (canvas_height, canvas_width), preserve_range=True).astype(self.zoom_image_orig.dtype)
1648
- #self.zoom_mask = resize(self.zoom_mask, (canvas_height, canvas_width), preserve_range=True).astype(np.uint8)
1649
- self.zoom_mask = resize(self.zoom_mask, (canvas_height, canvas_width), order=0, preserve_range=True).astype(np.uint8)
1650
- combined = self.overlay_mask_on_image(self.zoom_image, self.zoom_mask)
1651
- self.tk_image = ImageTk.PhotoImage(image=Image.fromarray(combined))
1652
- self.canvas.create_image(0, 0, anchor='nw', image=self.tk_image)
1653
-
1654
- def overlay_mask_on_image(self, image, mask, alpha=0.5):
1655
- if len(image.shape) == 2:
1656
- image = np.stack((image,) * 3, axis=-1)
1657
- mask = mask.astype(np.int32)
1658
- max_label = np.max(mask)
1659
- np.random.seed(0)
1660
- colors = np.random.randint(0, 255, size=(max_label + 1, 3), dtype=np.uint8)
1661
- colors[0] = [0, 0, 0] # background color
1662
- colored_mask = colors[mask]
1663
- image_8bit = (image / 256).astype(np.uint8)
1664
- # Blend the mask and the image with transparency
1665
- combined_image = np.where(mask[..., None] > 0,
1666
- np.clip(image_8bit * (1 - alpha) + colored_mask * alpha, 0, 255),
1667
- image_8bit)
1668
- # Convert the final image back to uint8
1669
- combined_image = combined_image.astype(np.uint8)
1670
- return combined_image
1671
-
1672
- ####################################################################################################
1673
- # Navigation functions#
1674
- ####################################################################################################
1675
-
1676
- def previous_image(self):
1677
- if self.current_image_index > 0:
1678
- self.current_image_index -= 1
1679
- self.initialize_flags()
1680
- self.image, self.mask = self.load_image_and_mask(self.current_image_index)
1681
- self.original_size = self.image.shape
1682
- self.image, self.mask = self.resize_arrays(self.image, self.mask)
1683
- self.display_image()
1684
-
1685
- def next_image(self):
1686
- if self.current_image_index < len(self.image_filenames) - 1:
1687
- self.current_image_index += 1
1688
- self.initialize_flags()
1689
- self.image, self.mask = self.load_image_and_mask(self.current_image_index)
1690
- self.original_size = self.image.shape
1691
- self.image, self.mask = self.resize_arrays(self.image, self.mask)
1692
- self.display_image()
1693
-
1694
- def save_mask(self):
1695
- if self.current_image_index < len(self.image_filenames):
1696
- original_size = self.original_size
1697
- if self.mask.shape != original_size:
1698
- resized_mask = resize(self.mask, original_size, order=0, preserve_range=True).astype(np.uint16)
1699
- else:
1700
- resized_mask = self.mask
1701
- resized_mask, _ = label(resized_mask > 0)
1702
- save_folder = os.path.join(self.folder_path, 'masks')
1703
- if not os.path.exists(save_folder):
1704
- os.makedirs(save_folder)
1705
- image_filename = os.path.splitext(self.image_filenames[self.current_image_index])[0] + '.tif'
1706
- save_path = os.path.join(save_folder, image_filename)
1707
-
1708
- print(f"Saving mask to: {save_path}") # Debug print
1709
- imageio.imwrite(save_path, resized_mask)
1710
-
1711
- ####################################################################################################
1712
- # Zoom Functions #
1713
- ####################################################################################################
1714
- def set_zoom_rectangle_start(self, event):
1715
- if self.zoom_active:
1716
- self.zoom_rectangle_start = (event.x, event.y)
1717
-
1718
- def set_zoom_rectangle_end(self, event):
1719
- if self.zoom_active:
1720
- self.zoom_rectangle_end = (event.x, event.y)
1721
- if self.zoom_rectangle_id is not None:
1722
- self.canvas.delete(self.zoom_rectangle_id)
1723
- self.zoom_rectangle_id = None
1724
- self.display_zoomed_image()
1725
- self.canvas.unbind("<Motion>")
1726
- self.canvas.unbind("<Button-1>")
1727
- self.canvas.unbind("<Button-3>")
1728
- self.canvas.bind("<Motion>", self.update_mouse_info)
1729
-
1730
- def update_zoom_box(self, event):
1731
- if self.zoom_active and self.zoom_rectangle_start is not None:
1732
- if self.zoom_rectangle_id is not None:
1733
- self.canvas.delete(self.zoom_rectangle_id)
1734
- # Assuming event.x and event.y are already in image coordinates
1735
- self.zoom_rectangle_end = (event.x, event.y)
1736
- x0, y0 = self.zoom_rectangle_start
1737
- x1, y1 = self.zoom_rectangle_end
1738
- self.zoom_rectangle_id = self.canvas.create_rectangle(x0, y0, x1, y1, outline="red", width=2)
1739
-
1740
- ####################################################################################################
1741
- # Mode activation#
1742
- ####################################################################################################
1743
-
1744
- def toggle_zoom_mode(self):
1745
- if not self.zoom_active:
1746
- self.brush_btn.config(text="Brush")
1747
- self.canvas.unbind("<B1-Motion>")
1748
- self.canvas.unbind("<B3-Motion>")
1749
- self.canvas.unbind("<ButtonRelease-1>")
1750
- self.canvas.unbind("<ButtonRelease-3>")
1751
- self.zoom_active = True
1752
- self.drawing = False
1753
- self.magic_wand_active = False
1754
- self.erase_active = False
1755
- self.brush_active = False
1756
- self.dividing_line_active = False
1757
- self.draw_btn.config(text="Draw")
1758
- self.erase_btn.config(text="Erase")
1759
- self.magic_wand_btn.config(text="Magic Wand")
1760
- self.zoom_btn.config(text="Zoom ON")
1761
- self.dividing_line_btn.config(text="Dividing Line")
1762
- self.canvas.unbind("<Button-1>")
1763
- self.canvas.unbind("<Button-3>")
1764
- self.canvas.unbind("<Motion>")
1765
- self.canvas.bind("<Button-1>", self.set_zoom_rectangle_start)
1766
- self.canvas.bind("<Button-3>", self.set_zoom_rectangle_end)
1767
- self.canvas.bind("<Motion>", self.update_zoom_box)
1768
- else:
1769
- self.zoom_active = False
1770
- self.zoom_btn.config(text="Zoom")
1771
- self.canvas.unbind("<Button-1>")
1772
- self.canvas.unbind("<Button-3>")
1773
- self.canvas.unbind("<Motion>")
1774
- self.zoom_rectangle_start = self.zoom_rectangle_end = None
1775
- self.zoom_rectangle_id = None
1776
- self.display_image()
1777
- self.canvas.bind("<Motion>", self.update_mouse_info)
1778
- self.zoom_rectangle_start = None
1779
- self.zoom_rectangle_end = None
1780
- self.zoom_rectangle_id = None
1781
- self.zoom_x0 = None
1782
- self.zoom_y0 = None
1783
- self.zoom_x1 = None
1784
- self.zoom_y1 = None
1785
- self.zoom_mask = None
1786
- self.zoom_image = None
1787
- self.zoom_image_orig = None
1788
-
1789
- def toggle_brush_mode(self):
1790
- self.brush_active = not self.brush_active
1791
- if self.brush_active:
1792
- self.drawing = False
1793
- self.magic_wand_active = False
1794
- self.erase_active = False
1795
- self.brush_btn.config(text="Brush ON")
1796
- self.draw_btn.config(text="Draw")
1797
- self.erase_btn.config(text="Erase")
1798
- self.magic_wand_btn.config(text="Magic Wand")
1799
- self.canvas.unbind("<Button-1>")
1800
- self.canvas.unbind("<Button-3>")
1801
- self.canvas.unbind("<Motion>")
1802
- self.canvas.bind("<B1-Motion>", self.apply_brush) # Left click and drag to apply brush
1803
- self.canvas.bind("<B3-Motion>", self.erase_brush) # Right click and drag to erase with brush
1804
- self.canvas.bind("<ButtonRelease-1>", self.apply_brush_release) # Left button release
1805
- self.canvas.bind("<ButtonRelease-3>", self.erase_brush_release) # Right button release
1806
- else:
1807
- self.brush_active = False
1808
- self.brush_btn.config(text="Brush")
1809
- self.canvas.unbind("<B1-Motion>")
1810
- self.canvas.unbind("<B3-Motion>")
1811
- self.canvas.unbind("<ButtonRelease-1>")
1812
- self.canvas.unbind("<ButtonRelease-3>")
1813
-
1814
- def image_to_canvas(self, x_image, y_image):
1815
- x_scale, y_scale = self.get_scaling_factors(
1816
- self.image.shape[1], self.image.shape[0],
1817
- self.canvas_width, self.canvas_height
1818
- )
1819
- x_canvas = int(x_image / x_scale)
1820
- y_canvas = int(y_image / y_scale)
1821
- return x_canvas, y_canvas
1822
-
1823
- def toggle_dividing_line_mode(self):
1824
- self.dividing_line_active = not self.dividing_line_active
1825
- if self.dividing_line_active:
1826
- self.drawing = False
1827
- self.magic_wand_active = False
1828
- self.erase_active = False
1829
- self.brush_active = False
1830
- self.draw_btn.config(text="Draw")
1831
- self.erase_btn.config(text="Erase")
1832
- self.magic_wand_btn.config(text="Magic Wand")
1833
- self.brush_btn.config(text="Brush")
1834
- self.dividing_line_btn.config(text="Dividing Line ON")
1835
- self.canvas.unbind("<Button-1>")
1836
- self.canvas.unbind("<ButtonRelease-1>")
1837
- self.canvas.unbind("<Motion>")
1838
- self.canvas.bind("<Button-1>", self.start_dividing_line)
1839
- self.canvas.bind("<ButtonRelease-1>", self.finish_dividing_line)
1840
- self.canvas.bind("<Motion>", self.update_dividing_line_preview)
1841
- else:
1842
- print("Dividing Line Mode: OFF")
1843
- self.dividing_line_active = False
1844
- self.dividing_line_btn.config(text="Dividing Line")
1845
- self.canvas.unbind("<Button-1>")
1846
- self.canvas.unbind("<ButtonRelease-1>")
1847
- self.canvas.unbind("<Motion>")
1848
- self.display_image()
1849
-
1850
- def start_dividing_line(self, event):
1851
- if self.dividing_line_active:
1852
- self.dividing_line_coords = [(event.x, event.y)]
1853
- self.current_dividing_line = self.canvas.create_line(event.x, event.y, event.x, event.y, fill="red", width=2)
1854
-
1855
- def finish_dividing_line(self, event):
1856
- if self.dividing_line_active:
1857
- self.dividing_line_coords.append((event.x, event.y))
1858
- if self.zoom_active:
1859
- self.dividing_line_coords = [self.canvas_to_image(x, y) for x, y in self.dividing_line_coords]
1860
- self.apply_dividing_line()
1861
- self.canvas.delete(self.current_dividing_line)
1862
- self.current_dividing_line = None
1863
-
1864
- def update_dividing_line_preview(self, event):
1865
- if self.dividing_line_active and self.dividing_line_coords:
1866
- x, y = event.x, event.y
1867
- if self.zoom_active:
1868
- x, y = self.canvas_to_image(x, y)
1869
- self.dividing_line_coords.append((x, y))
1870
- canvas_coords = [(self.image_to_canvas(*pt) if self.zoom_active else pt) for pt in self.dividing_line_coords]
1871
- flat_canvas_coords = [coord for pt in canvas_coords for coord in pt]
1872
- self.canvas.coords(self.current_dividing_line, *flat_canvas_coords)
1873
-
1874
- def apply_dividing_line(self):
1875
- if self.dividing_line_coords:
1876
- coords = self.dividing_line_coords
1877
- if self.zoom_active:
1878
- coords = [self.canvas_to_image(x, y) for x, y in coords]
1879
-
1880
- rr, cc = [], []
1881
- for (x0, y0), (x1, y1) in zip(coords[:-1], coords[1:]):
1882
- line_rr, line_cc = line(y0, x0, y1, x1)
1883
- rr.extend(line_rr)
1884
- cc.extend(line_cc)
1885
- rr, cc = np.array(rr), np.array(cc)
1886
-
1887
- mask_copy = self.mask.copy()
1888
-
1889
- if self.zoom_active:
1890
- # Update the zoomed mask
1891
- self.zoom_mask[rr, cc] = 0
1892
- # Reflect changes to the original mask
1893
- y0, y1, x0, x1 = self.zoom_y0, self.zoom_y1, self.zoom_x0, self.zoom_x1
1894
- zoomed_mask_resized_back = resize(self.zoom_mask, (y1 - y0, x1 - x0), order=0, preserve_range=True).astype(np.uint8)
1895
- self.mask[y0:y1, x0:x1] = zoomed_mask_resized_back
1896
- else:
1897
- # Directly update the original mask
1898
- mask_copy[rr, cc] = 0
1899
- self.mask = mask_copy
1900
-
1901
- labeled_mask, num_labels = label(self.mask > 0)
1902
- self.mask = labeled_mask
1903
- self.update_display()
1904
-
1905
- self.dividing_line_coords = []
1906
- self.canvas.unbind("<Button-1>")
1907
- self.canvas.unbind("<ButtonRelease-1>")
1908
- self.canvas.unbind("<Motion>")
1909
- self.dividing_line_active = False
1910
- self.dividing_line_btn.config(text="Dividing Line")
1911
-
1912
def toggle_draw_mode(self):
    """Switch freehand polygon drawing on or off.

    Turning it on deactivates the brush, magic-wand and erase tools, resets
    their button captions, starts a fresh point list and routes left-button
    drag/release to ``draw`` / ``finish_drawing``. Turning it off restores the
    "Draw" caption and removes those two bindings.
    """
    self.drawing = not self.drawing
    if not self.drawing:
        self.drawing = False
        self.draw_btn.config(text="Draw")
        for sequence in ("<B1-Motion>", "<ButtonRelease-1>"):
            self.canvas.unbind(sequence)
        return

    self.brush_btn.config(text="Brush")
    for sequence in ("<B1-Motion>", "<B3-Motion>", "<ButtonRelease-1>", "<ButtonRelease-3>"):
        self.canvas.unbind(sequence)
    self.magic_wand_active = self.erase_active = self.brush_active = False
    for button, caption in ((self.draw_btn, "Draw ON"), (self.magic_wand_btn, "Magic Wand"), (self.erase_btn, "Erase")):
        button.config(text=caption)
    self.draw_coordinates = []
    for sequence in ("<Button-1>", "<Motion>"):
        self.canvas.unbind(sequence)
    self.canvas.bind("<B1-Motion>", self.draw)
    self.canvas.bind("<ButtonRelease-1>", self.finish_drawing)
1936
-
1937
def toggle_magic_wand_mode(self):
    """Switch the magic-wand tool on or off.

    When enabled, the other tools are deactivated, their captions reset, and
    both left and right clicks are routed to ``use_magic_wand``. When disabled,
    the caption is restored and both click bindings are removed.
    """
    self.magic_wand_active = not self.magic_wand_active
    if not self.magic_wand_active:
        self.magic_wand_btn.config(text="Magic Wand")
        for sequence in ("<Button-1>", "<Button-3>"):
            self.canvas.unbind(sequence)
        return

    self.brush_btn.config(text="Brush")
    for sequence in ("<B1-Motion>", "<B3-Motion>", "<ButtonRelease-1>", "<ButtonRelease-3>"):
        self.canvas.unbind(sequence)
    self.drawing = self.erase_active = self.brush_active = False
    for button, caption in ((self.draw_btn, "Draw"), (self.erase_btn, "Erase"), (self.magic_wand_btn, "Magic Wand ON")):
        button.config(text=caption)
    for sequence in ("<Button-1>", "<Button-3>"):
        self.canvas.bind(sequence, self.use_magic_wand)
1957
-
1958
def toggle_erase_mode(self):
    """Switch the click-to-erase-object tool on or off.

    When enabled, left clicks go to ``erase_object`` and the other tools are
    deactivated with their captions reset. When disabled, the caption is
    restored and the left-click binding removed.
    """
    self.erase_active = not self.erase_active
    if not self.erase_active:
        self.erase_active = False
        self.erase_btn.config(text="Erase")
        self.canvas.unbind("<Button-1>")
        return

    self.brush_btn.config(text="Brush")
    for sequence in ("<B1-Motion>", "<B3-Motion>", "<ButtonRelease-1>", "<ButtonRelease-3>"):
        self.canvas.unbind(sequence)
    self.erase_btn.config(text="Erase ON")
    self.canvas.bind("<Button-1>", self.erase_object)
    self.drawing = self.magic_wand_active = self.brush_active = False
    for button, caption in ((self.draw_btn, "Draw"), (self.magic_wand_btn, "Magic Wand")):
        button.config(text=caption)
1977
-
1978
- ####################################################################################################
1979
- # Mode functions#
1980
- ####################################################################################################
1981
-
1982
def apply_brush_release(self, event):
    """Commit the pending brush stroke: paint every recorded point as 255.

    Each entry of ``self.brush_path`` is ``(x, y, brush_size)`` in canvas
    coordinates. In zoom mode the point indexes ``self.zoom_mask`` directly;
    otherwise it is mapped through ``canvas_to_image`` onto ``self.mask``.
    A square of ``brush_size`` pixels centred on the point is painted, clipped
    to the mask. Afterwards the path and the temporary canvas preview are
    removed and the display is refreshed.
    """
    if hasattr(self, 'brush_path'):
        target = self.zoom_mask if self.zoom_active else self.mask
        height, width = target.shape[0], target.shape[1]
        for x, y, brush_size in self.brush_path:
            img_x, img_y = (x, y) if self.zoom_active else self.canvas_to_image(x, y)
            img_x, img_y = int(img_x), int(img_y)
            half = brush_size // 2
            # Exactly brush_size pixels per side; the former "+/- brush_size // 2"
            # bounds painted nothing for size 1 and one pixel too few for odd sizes.
            x0 = max(img_x - half, 0)
            y0 = max(img_y - half, 0)
            x1 = min(img_x - half + brush_size, width)
            y1 = min(img_y - half + brush_size, height)
            target[y0:y1, x0:x1] = 255
        if self.zoom_active and self.brush_path:
            # Sync once per stroke rather than once per painted point.
            self.update_original_mask_from_zoom()
        del self.brush_path
        self.canvas.delete("temp_line")
        self.update_display()
1998
-
1999
def erase_brush_release(self, event):
    """Commit the pending erase stroke: clear every recorded point to 0.

    Each entry of ``self.erase_path`` is ``(x, y, brush_size)`` in canvas
    coordinates. In zoom mode the point indexes ``self.zoom_mask`` directly;
    otherwise it is mapped through ``canvas_to_image`` onto ``self.mask``.
    A square of ``brush_size`` pixels centred on the point is cleared, clipped
    to the mask. Afterwards the path and the temporary canvas preview are
    removed and the display is refreshed.
    """
    if hasattr(self, 'erase_path'):
        target = self.zoom_mask if self.zoom_active else self.mask
        height, width = target.shape[0], target.shape[1]
        for x, y, brush_size in self.erase_path:
            img_x, img_y = (x, y) if self.zoom_active else self.canvas_to_image(x, y)
            img_x, img_y = int(img_x), int(img_y)
            half = brush_size // 2
            # Exactly brush_size pixels per side; the former "+/- brush_size // 2"
            # bounds erased nothing for size 1 and one pixel too few for odd sizes.
            x0 = max(img_x - half, 0)
            y0 = max(img_y - half, 0)
            x1 = min(img_x - half + brush_size, width)
            y1 = min(img_y - half + brush_size, height)
            target[y0:y1, x0:x1] = 0
        if self.zoom_active and self.erase_path:
            # Sync once per stroke rather than once per erased point.
            self.update_original_mask_from_zoom()
        del self.erase_path
        self.canvas.delete("temp_line")
        self.update_display()
2015
-
2016
def apply_brush(self, event):
    """Record one drag step of the paint brush and preview it in blue.

    Every pixel on the straight segment from the previous pointer position to
    the current one is appended to ``self.brush_path`` as ``(x, y, size)``.
    The segment is also drawn on the canvas tagged ``temp_line``. The first
    event of a stroke creates the path and anchors it at the current point.
    """
    size = int(self.brush_size_entry.get())
    current = (event.x, event.y)
    if not hasattr(self, 'brush_path'):
        self.brush_path = []
        self.last_brush_coord = current
    previous = self.last_brush_coord
    if previous:
        rows, cols = line(previous[1], previous[0], current[1], current[0])
        self.brush_path.extend((col, row, size) for row, col in zip(rows, cols))
        self.canvas.create_line(previous[0], previous[1], current[0], current[1], width=size, fill="blue", tag="temp_line")
    self.last_brush_coord = current
2030
-
2031
def erase_brush(self, event):
    """Record one drag step of the erase brush and preview it in white.

    Every pixel on the straight segment from the previous pointer position to
    the current one is appended to ``self.erase_path`` as ``(x, y, size)``.
    The segment is also drawn on the canvas tagged ``temp_line``. The first
    event of a stroke creates the path and anchors it at the current point.
    """
    size = int(self.brush_size_entry.get())
    current = (event.x, event.y)
    if not hasattr(self, 'erase_path'):
        self.erase_path = []
        self.last_erase_coord = current
    previous = self.last_erase_coord
    if previous:
        rows, cols = line(previous[1], previous[0], current[1], current[0])
        self.erase_path.extend((col, row, size) for row, col in zip(rows, cols))
        self.canvas.create_line(previous[0], previous[1], current[0], current[1], width=size, fill="white", tag="temp_line")
    self.last_erase_coord = current
2045
-
2046
def erase_object(self, event):
    """Remove the whole labelled object under the clicked point.

    In zoom mode the canvas position is mapped back into full-mask
    coordinates using the zoom window (``zoom_x0..zoom_x1``,
    ``zoom_y0..zoom_y1``). Otherwise the click is used as a mask coordinate
    directly. Clicks outside the mask are reported and ignored in both modes;
    previously the non-zoom path raised IndexError or, for negative
    coordinates, silently wrapped to the opposite edge. Every pixel carrying
    the clicked label is set to 0.
    """
    x, y = event.x, event.y
    if self.zoom_active:
        zoomed_x = int(x * (self.zoom_image.shape[1] / self.canvas_width))
        zoomed_y = int(y * (self.zoom_image.shape[0] / self.canvas_height))
        orig_x = int(zoomed_x * ((self.zoom_x1 - self.zoom_x0) / self.canvas_width) + self.zoom_x0)
        orig_y = int(zoomed_y * ((self.zoom_y1 - self.zoom_y0) / self.canvas_height) + self.zoom_y0)
    else:
        orig_x, orig_y = x, y

    height, width = self.mask.shape[0], self.mask.shape[1]
    if orig_x < 0 or orig_y < 0 or orig_x >= width or orig_y >= height:
        print("Point is out of bounds in the original image.")
        return

    label_to_remove = self.mask[orig_y, orig_x]
    if label_to_remove > 0:
        self.mask[self.mask == label_to_remove] = 0
        self.update_display()
2063
-
2064
def use_magic_wand(self, event):
    """Dispatch a canvas click to the magic wand.

    Button 1 adds the region to the mask; any other button erases it. The
    tolerance comes from ``self.magic_wand_tolerance``. The pixel limit is read
    (with a fallback for invalid input) by ``magic_wand_normal`` /
    ``magic_wand_zoomed`` themselves. Parsing it here as well was unused and
    made an invalid entry raise ValueError before the helpers' fallback could
    apply.
    """
    tolerance = int(self.magic_wand_tolerance.get())
    action = 'add' if event.num == 1 else 'erase'
    seed = (event.x, event.y)
    if self.zoom_active:
        self.magic_wand_zoomed(seed, tolerance, action)
    else:
        self.magic_wand_normal(seed, tolerance, action)
2073
-
2074
def apply_magic_wand(self, image, mask, seed_point, tolerance, maximum, action='add'):
    """Flood-fill from a seed over similar pixels and set them in ``mask``.

    A breadth-first search over 4-connected neighbours accepts a pixel when
    the Euclidean norm of its difference from the seed value is at most
    ``tolerance``. This works for 2-D (grayscale) and 3-D (per-channel)
    images. Accepted pixels are set to 255 for ``action='add'`` and 0
    otherwise. The fill stops once ``maximum`` pixels have actually changed.

    Args:
        image: 2-D or 3-D array indexed ``image[y, x]``.
        mask: 2-D array with the image's height and width; modified in place.
        seed_point: ``(x, y)`` start pixel; out-of-bounds seeds leave ``mask`` unchanged.
        tolerance: maximum allowed distance from the seed value.
        maximum: maximum number of mask pixels to change.
        action: ``'add'`` to fill, anything else to erase.

    Returns:
        The same ``mask`` object.
    """
    height, width = image.shape[0], image.shape[1]
    x, y = seed_point
    if not (0 <= x < width and 0 <= y < height):
        return mask

    initial_value = image[y, x].astype(np.float32)
    # One flag per pixel; zeros_like(image) was 3-D for colour images, which made
    # `if visited[cy, cx]` an ambiguous array truth test.
    visited = np.zeros((height, width), dtype=bool)
    adding = action == 'add'
    fill_value = 255 if adding else 0
    queue = deque([(x, y)])
    changed = 0

    while queue and changed < maximum:
        cx, cy = queue.popleft()
        if visited[cy, cx]:
            continue
        visited[cy, cx] = True
        current_value = image[cy, cx].astype(np.float32)
        if np.linalg.norm(current_value - initial_value) > tolerance:
            continue

        # Count only pixels whose value actually changes, so `maximum` also
        # limits erasing (previously already-empty pixels were counted instead).
        if (mask[cy, cx] == 0) if adding else (mask[cy, cx] != 0):
            changed += 1
        mask[cy, cx] = fill_value
        if changed >= maximum:
            break

        for dx, dy in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            nx, ny = cx + dx, cy + dy
            if 0 <= nx < width and 0 <= ny < height and not visited[ny, nx]:
                queue.append((nx, ny))
    return mask
2101
-
2102
def magic_wand_normal(self, seed_point, tolerance, action):
    """Apply the magic wand to the full-resolution image and mask.

    The pixel limit is read from ``self.max_pixels_entry``; when it is not an
    integer a message is printed and 1000 is used instead.
    """
    try:
        pixel_limit = int(self.max_pixels_entry.get())
    except ValueError:
        print("Invalid maximum value; using default of 1000")
        pixel_limit = 1000
    updated_mask = self.apply_magic_wand(self.image, self.mask, seed_point, tolerance, pixel_limit, action)
    self.mask = updated_mask
    self.display_image()
2110
-
2111
def magic_wand_zoomed(self, seed_point, tolerance, action):
    """Apply the magic wand in zoom mode and merge the result into the full mask.

    The fill runs on ``self.zoom_image_orig`` / ``self.zoom_mask`` with the
    pixel limit scaled by ``self.zoom_scale`` (1000 when the entry is invalid).
    The zoomed mask is then resized with nearest-neighbour interpolation to
    the zoom window of ``self.mask``. In erase mode, window pixels that are 0
    in the resized mask are cleared; otherwise its non-zero pixels are copied in.
    """
    if self.zoom_image_orig is None or self.zoom_mask is None:
        print("Zoomed image or mask not initialized")
        return
    try:
        pixel_limit = int(self.max_pixels_entry.get()) * self.zoom_scale
    except ValueError:
        print("Invalid maximum value; using default of 1000")
        pixel_limit = 1000

    seed_x, seed_y = seed_point
    zoom_h, zoom_w = self.zoom_image_orig.shape[0], self.zoom_image_orig.shape[1]
    if not (0 <= seed_x < zoom_w and 0 <= seed_y < zoom_h):
        print("Selected point is out of bounds in the zoomed image.")
        return

    self.zoom_mask = self.apply_magic_wand(self.zoom_image_orig, self.zoom_mask, (seed_x, seed_y), tolerance, pixel_limit, action)
    top, bottom, left, right = self.zoom_y0, self.zoom_y1, self.zoom_x0, self.zoom_x1
    restored = resize(self.zoom_mask, (bottom - top, right - left), order=0, preserve_range=True).astype(np.uint8)
    window = self.mask[top:bottom, left:right]
    if action == 'erase':
        merged = np.where(restored == 0, 0, window)
    else:
        merged = np.where(restored > 0, restored, window)
    self.mask[top:bottom, left:right] = merged
    self.update_display()
2135
-
2136
def draw(self, event):
    """Extend the freehand outline to the pointer while drawing mode is on.

    From the second point onward a yellow segment from the previous point is
    drawn, and its canvas id is kept in ``self.current_line``.
    """
    if not self.drawing:
        return
    point = (event.x, event.y)
    if self.draw_coordinates:
        prev_x, prev_y = self.draw_coordinates[-1]
        self.current_line = self.canvas.create_line(prev_x, prev_y, point[0], point[1], fill="yellow", width=3)
    self.draw_coordinates.append(point)
2143
-
2144
def draw_on_zoomed_mask(self, draw_coordinates):
    """Rasterise a closed outline into a new canvas-sized uint8 mask.

    ``draw_coordinates`` are ``(x, y)`` canvas points. Pixels inside the
    polygon are 255, everything else 0. The mask has the canvas's current
    height and width.
    """
    height = self.canvas.winfo_height()
    width = self.canvas.winfo_width()
    points = np.array(draw_coordinates)
    canvas_mask = np.zeros((height, width), dtype=np.uint8)
    rows, cols = polygon(points[:, 1], points[:, 0], shape=canvas_mask.shape)
    canvas_mask[rows, cols] = 255
    return canvas_mask
2151
-
2152
def finish_drawing(self, event):
    """Close the freehand outline and fill it into the mask.

    Nothing happens with two points or fewer. Otherwise the outline is closed
    back to its first point. In zoom mode it is rasterised on a canvas-sized
    mask and handed to ``update_original_mask`` with the zoom window. In
    normal mode the polygon pixels are raised to at least 255 in
    ``self.mask``. The last preview segment is deleted, the point list is
    cleared and the display is refreshed.
    """
    if len(self.draw_coordinates) <= 2:
        return
    self.draw_coordinates.append(self.draw_coordinates[0])
    if self.zoom_active:
        left, right, top, bottom = self.zoom_x0, self.zoom_x1, self.zoom_y0, self.zoom_y1
        outline_mask = self.draw_on_zoomed_mask(self.draw_coordinates)
        self.update_original_mask(outline_mask, left, right, top, bottom)
    else:
        points = np.array(self.draw_coordinates)
        rows, cols = polygon(points[:, 1], points[:, 0], shape=self.mask.shape)
        self.mask[rows, cols] = np.maximum(self.mask[rows, cols], 255)
    self.mask = self.mask.copy()
    self.canvas.delete(self.current_line)
    self.draw_coordinates.clear()
    self.update_display()
2166
-
2167
def finish_drawing_if_active(self, event):
    """Finish the outline only when drawing mode is on and it has more than two points."""
    enough_points = len(self.draw_coordinates) > 2
    if self.drawing and enough_points:
        self.finish_drawing(event)
2170
-
2171
- ####################################################################################################
2172
- # Single function buttons#
2173
- ####################################################################################################
2174
-
2175
def apply_normalization(self):
    """Copy the lower/upper quantile entry texts into their variables and redraw."""
    lower_text = self.lower_entry.get()
    self.lower_quantile.set(lower_text)
    upper_text = self.upper_entry.get()
    self.upper_quantile.set(upper_text)
    self.update_display()
2179
-
2180
def fill_objects(self):
    """Fill holes in all mask objects, relabel them and redraw.

    The foreground (``mask > 0``) is hole-filled with ``binary_fill_holes``
    and the result is relabelled into connected components. The old
    intermediate ``uint8 * 255`` assignment was overwritten immediately and
    has been dropped.
    """
    filled = binary_fill_holes(self.mask > 0)
    self.mask, _ = label(filled)
    self.update_display()
2187
-
2188
def relabel_objects(self):
    """Replace the mask with fresh connected-component labels of its foreground and redraw."""
    foreground = self.mask > 0
    self.mask = label(foreground)[0]
    self.update_display()
2193
-
2194
def clear_objects(self):
    """Reset the mask to all zeros (same shape and dtype) and redraw."""
    blank = np.zeros_like(self.mask)
    self.mask = blank
    self.update_display()
2197
-
2198
def invert_mask(self):
    """Swap foreground and background, then relabel the new foreground.

    ``relabel_objects`` already refreshes the display, so the extra
    ``update_display()`` call that followed it, which redrew a second time,
    has been removed.
    """
    self.mask = np.where(self.mask > 0, 0, 1)
    self.relabel_objects()
2202
-
2203
def remove_small_objects(self):
    """Zero out connected foreground components smaller than the minimum area.

    The minimum area comes from ``self.min_area_entry``; a non-integer entry
    prints a message and falls back to 100. Components are found with
    ``label(mask > 0)`` and sized with a single ``np.bincount``. The former
    loop rescanned the whole mask once per label. ``self.mask`` is modified
    in place.
    """
    try:
        min_area = int(self.min_area_entry.get())
    except ValueError:
        print("Invalid minimum area value; using default of 100")
        min_area = 100

    labeled_mask, _ = label(self.mask > 0)
    areas = np.bincount(labeled_mask.ravel(), minlength=1)
    too_small = areas < min_area
    too_small[0] = False  # label 0 is background, never removed
    self.mask[too_small[labeled_mask]] = 0
    self.update_display()
2215
-
2216
- class AnnotateApp:
2217
def __init__(self, root, db_path, src, image_type=None, channels=None, image_size=200, annotation_column='annotate', normalize=False, percentiles=(1, 99), measurement=None, threshold=None, normalize_channels=None, outline=None, outline_threshold_factor=1, outline_sigma=1):
    """Build the annotation window, load the path list and start the DB writer thread.

    Args:
        root: Tk window the app draws into.
        db_path: SQLite database holding the ``png_list`` table.
        src: project folder; image paths from the database are re-rooted under it.
        image_type: optional filter on image paths.
        channels: channel selection used by the image filters.
        image_size: thumbnail size as (width, height). An int gives a square; a
            list uses its first element for both sides; a 2-tuple is used as given.
        annotation_column: ``png_list`` column holding annotations.
        normalize, percentiles, normalize_channels: intensity normalisation options.
        measurement, threshold: optional measurement filter for the path list.
        outline, outline_threshold_factor, outline_sigma: per-channel outline options.

    Raises:
        ValueError: if ``image_size`` is not an int, a list or a 2-tuple.
    """
    self.root = root
    self.db_path = db_path
    self.src = src
    self.index = 0

    if isinstance(image_size, list):
        self.image_size = (int(image_size[0]), int(image_size[0]))
    elif isinstance(image_size, int):
        self.image_size = (image_size, image_size)
    elif isinstance(image_size, tuple) and len(image_size) == 2:
        # Accepted for consistency with update_settings(), which already allows tuples.
        self.image_size = tuple(map(int, image_size))
    else:
        raise ValueError("Invalid image size")

    self.orig_annotation_columns = annotation_column
    self.annotation_column = annotation_column
    self.image_type = image_type
    self.channels = channels
    self.normalize = normalize
    self.percentiles = percentiles
    self.images = {}
    self.pending_updates = {}
    self.labels = []
    self.adjusted_to_original_paths = {}
    self.terminate = False
    self.update_queue = Queue()
    self.measurement = measurement
    self.threshold = threshold
    self.normalize_channels = normalize_channels
    self.outline = outline
    self.outline_threshold_factor = outline_threshold_factor
    self.outline_sigma = outline_sigma

    style_out = set_dark_style(ttk.Style())
    self.font_loader = style_out['font_loader']
    self.font_size = style_out['font_size']
    self.bg_color = style_out['bg_color']
    self.fg_color = style_out['fg_color']
    self.active_color = style_out['active_color']
    self.inactive_color = style_out['inactive_color']

    if self.font_loader:
        self.font_style = self.font_loader.get_font(size=self.font_size)
    else:
        self.font_style = ("Arial", 12)

    self.root.configure(bg=style_out['inactive_color'])

    self.filtered_paths_annotations = []
    self.prefilter_paths_annotations()

    # Background thread running update_database_worker.
    self.db_update_thread = threading.Thread(target=self.update_database_worker)
    self.db_update_thread.start()

    # Size the window to the full screen.
    self.root.geometry(f"{self.root.winfo_screenwidth()}x{self.root.winfo_screenheight()}")
    self.root.update_idletasks()

    self.status_label = Label(root, text="", font=self.font_style, bg=self.root.cget('bg'))
    self.status_label.grid(row=2, column=0, padx=10, pady=10, sticky="w")

    # Buttons live in a frame at the bottom right.
    self.button_frame = Frame(root, bg=self.root.cget('bg'))
    self.button_frame.grid(row=2, column=1, padx=10, pady=10, sticky="se")

    self.next_button = Button(self.button_frame, text="Next", command=self.next_page, bg=self.bg_color, fg=self.fg_color, highlightbackground=self.fg_color, highlightcolor=self.fg_color, highlightthickness=1)
    self.next_button.pack(side="right", padx=5)

    self.previous_button = Button(self.button_frame, text="Back", command=self.previous_page, bg=self.bg_color, fg=self.fg_color, highlightbackground=self.fg_color, highlightcolor=self.fg_color, highlightthickness=1)
    self.previous_button.pack(side="right", padx=5)

    self.exit_button = Button(self.button_frame, text="Exit", command=self.shutdown, bg=self.bg_color, fg=self.fg_color, highlightbackground=self.fg_color, highlightcolor=self.fg_color, highlightthickness=1)
    self.exit_button.pack(side="right", padx=5)

    self.train_button = Button(self.button_frame, text="Train & Classify (beta)", command=self.train_and_classify, bg=self.bg_color, fg=self.fg_color, highlightbackground=self.fg_color, highlightcolor=self.fg_color, highlightthickness=1)
    self.train_button.pack(side="right", padx=5)

    # Kept under its own attribute: it used to be stored in self.train_button,
    # overwriting the reference to the "Train & Classify" button.
    self.orig_button = Button(self.button_frame, text="orig.", command=self.swich_back_annotation_column, bg=self.bg_color, fg=self.fg_color, highlightbackground=self.fg_color, highlightcolor=self.fg_color, highlightthickness=1)
    self.orig_button.pack(side="right", padx=5)

    self.settings_button = Button(self.button_frame, text="Settings", command=self.open_settings_window, bg=self.bg_color, fg=self.fg_color, highlightbackground=self.fg_color, highlightcolor=self.fg_color, highlightthickness=1)
    self.settings_button.pack(side="right", padx=5)

    # Grid size follows the window size and thumbnail size.
    self.calculate_grid_dimensions()

    self.grid_frame = Frame(root, bg=self.root.cget('bg'))
    self.grid_frame.grid(row=0, column=0, columnspan=2, padx=0, pady=0, sticky="nsew")

    for i in range(self.grid_rows * self.grid_cols):
        label = Label(self.grid_frame, bg=self.root.cget('bg'))
        label.grid(row=i // self.grid_cols, column=i % self.grid_cols, padx=2, pady=2, sticky="nsew")
        self.labels.append(label)

    # Let the image grid stretch with the window.
    self.root.grid_rowconfigure(0, weight=1)
    self.root.grid_columnconfigure(0, weight=1)
    self.root.grid_columnconfigure(1, weight=1)

    for row in range(self.grid_rows):
        self.grid_frame.grid_rowconfigure(row, weight=1)
    for col in range(self.grid_cols):
        self.grid_frame.grid_columnconfigure(col, weight=1)
2321
-
2322
def open_settings_window(self):
    """Open a dialog pre-filled with the current settings; "Apply" updates the app.

    Entered values are parsed into the same types the constructor takes and
    passed to ``update_settings``; empty entries become ``None``. When the
    measurement/threshold pair cannot be parsed (for example an empty or
    non-integer threshold), both are set to ``None``.
    """
    from .gui_utils import generate_annotate_fields, convert_to_number

    def _as_csv(value):
        # Lists are shown comma-separated; a plain string is shown as-is rather
        # than being split into comma-separated characters by ','.join().
        if not value:
            return ''
        if isinstance(value, str):
            return value
        return ','.join(map(str, value))

    settings_window = tk.Toplevel(self.root)
    settings_window.title("Modify Annotation Settings")

    style_out = set_dark_style(ttk.Style())
    settings_window.configure(bg=style_out['bg_color'])

    settings_frame = tk.Frame(settings_window, bg=style_out['bg_color'])
    settings_frame.pack(fill=tk.BOTH, expand=True)

    vars_dict = generate_annotate_fields(settings_frame)

    current_settings = {
        'image_type': self.image_type or '',
        'channels': _as_csv(self.channels),
        'img_size': f"{self.image_size[0]},{self.image_size[1]}",
        'annotation_column': self.annotation_column or '',
        'normalize': str(self.normalize),
        'percentiles': ','.join(map(str, self.percentiles)),
        'measurement': _as_csv(self.measurement),
        'threshold': str(self.threshold) if self.threshold is not None else '',
        'normalize_channels': _as_csv(self.normalize_channels),
        'outline': _as_csv(self.outline),
        'outline_threshold_factor': str(self.outline_threshold_factor) if hasattr(self, 'outline_threshold_factor') else '1.0',
        'outline_sigma': str(self.outline_sigma) if hasattr(self, 'outline_sigma') else '1.0',
        'src': self.src,
        'db_path': self.db_path,
    }

    for key, data in vars_dict.items():
        if key in current_settings:
            data['entry'].delete(0, tk.END)
            data['entry'].insert(0, current_settings[key])

    def apply_new_settings():
        settings = {key: data['entry'].get() for key, data in vars_dict.items()}

        settings['channels'] = settings['channels'].split(',') if settings['channels'] else None
        settings['img_size'] = list(map(int, settings['img_size'].split(',')))
        settings['percentiles'] = list(map(convert_to_number, settings['percentiles'].split(','))) if settings['percentiles'] else [1, 99]
        settings['normalize'] = settings['normalize'].lower() == 'true'
        settings['normalize_channels'] = settings['normalize_channels'].split(',') if settings['normalize_channels'] else None
        settings['outline'] = settings['outline'].split(',') if settings['outline'] else None
        settings['outline_threshold_factor'] = float(settings['outline_threshold_factor'].replace(',', '.')) if settings['outline_threshold_factor'] else 1.0
        settings['outline_sigma'] = float(settings['outline_sigma'].replace(',', '.')) if settings['outline_sigma'] else 1.0

        try:
            settings['measurement'] = settings['measurement'].split(',') if settings['measurement'] else None
            settings['threshold'] = None if settings['threshold'].lower() == 'none' else int(settings['threshold'])
        except (KeyError, ValueError):
            # Missing field or unparsable threshold: disable measurement filtering.
            # (Previously a bare except, which also swallowed KeyboardInterrupt.)
            settings['measurement'] = None
            settings['threshold'] = None

        # Empty strings mean "not set".
        for key, value in settings.items():
            if isinstance(value, list):
                settings[key] = [v if v != '' else None for v in value]
            elif value == '':
                settings[key] = None

        self.update_settings(**{
            'image_type': settings.get('image_type'),
            'channels': settings.get('channels'),
            'image_size': settings.get('img_size'),
            'annotation_column': settings.get('annotation_column'),
            'normalize': settings.get('normalize'),
            'percentiles': settings.get('percentiles'),
            'measurement': settings.get('measurement'),
            'threshold': settings.get('threshold'),
            'normalize_channels': settings.get('normalize_channels'),
            'outline': settings.get('outline'),
            'outline_threshold_factor': settings.get('outline_threshold_factor'),
            'outline_sigma': settings.get('outline_sigma'),
            'src': self.src,
            'db_path': self.db_path
        })

        settings_window.destroy()

    apply_button = spacrButton(settings_window, text="Apply Settings", command=apply_new_settings, show_text=False)
    apply_button.pack(pady=10)
2410
-
2411
def update_settings(self, **kwargs):
    """Apply new settings and refresh the grid and the current page as needed.

    Only known attribute names with a non-``None`` value are applied; a
    ``None`` value leaves the current setting in place. ``outline`` strings
    are split into lowercase channel letters, and the outline factor/sigma
    are cast to float. When ``image_size`` is passed, it is normalised to a
    tuple and the grid is rebuilt. When anything was applied, the path list
    is re-filtered, the page index is clamped so the last page stays full
    where possible, and the images are reloaded.

    Raises:
        ValueError: if the resulting ``image_size`` is not a list, int or 2-tuple.
    """
    allowed_attributes = {
        'image_type', 'channels', 'image_size', 'annotation_column', 'src', 'db_path',
        'normalize', 'percentiles', 'measurement', 'threshold', 'normalize_channels', 'outline', 'outline_threshold_factor', 'outline_sigma'
    }

    updated = False
    for attr, value in kwargs.items():
        if attr not in allowed_attributes or value is None:
            continue
        if attr == 'outline' and isinstance(value, str):
            value = [s.strip().lower() for s in value.split(',') if s.strip()]
        elif attr in ('outline_threshold_factor', 'outline_sigma'):
            value = float(value)
        setattr(self, attr, value)
        updated = True

    if 'image_size' in kwargs:
        if isinstance(self.image_size, list):
            self.image_size = (int(self.image_size[0]), int(self.image_size[0]))
        elif isinstance(self.image_size, int):
            self.image_size = (self.image_size, self.image_size)
        elif isinstance(self.image_size, tuple) and len(self.image_size) == 2:
            self.image_size = tuple(map(int, self.image_size))
        else:
            raise ValueError("Invalid image size")

        self.calculate_grid_dimensions()
        self.recreate_image_grid()

    if updated:
        current_index = self.index
        self.prefilter_paths_annotations()
        # Largest start index that still fills a whole page (0 if everything fits).
        max_index = max(0, len(self.filtered_paths_annotations) - self.grid_rows * self.grid_cols)
        self.index = min(current_index, max_index)
        self.load_images()
2453
-
2454
def recreate_image_grid(self):
    """Destroy the thumbnail labels and rebuild a grid_rows x grid_cols label grid.

    Every row and column of the new grid gets weight 1 so it stretches with
    the window.
    """
    for old_label in self.labels:
        old_label.destroy()
    self.labels.clear()

    background = self.root.cget('bg')
    for cell in range(self.grid_rows * self.grid_cols):
        row, col = divmod(cell, self.grid_cols)
        cell_label = Label(self.grid_frame, bg=background)
        cell_label.grid(row=row, column=col, padx=2, pady=2, sticky="nsew")
        self.labels.append(cell_label)

    for row in range(self.grid_rows):
        self.grid_frame.grid_rowconfigure(row, weight=1)
    for col in range(self.grid_cols):
        self.grid_frame.grid_columnconfigure(col, weight=1)
2471
-
2472
-
2473
def swich_back_annotation_column(self):
    """Restore the annotation column the app was started with and refresh.

    NOTE(review): this calls ``self.update_display()``; confirm AnnotateApp
    defines it (it is not among the visible AnnotateApp methods).
    """
    original_column = self.orig_annotation_columns
    self.annotation_column = original_column
    self.prefilter_paths_annotations()
    self.update_display()
2477
-
2478
def calculate_grid_dimensions(self):
    """Compute how many thumbnails fit in the window (at least 1 x 1).

    Each cell is the thumbnail size plus 4 px of padding. The height available
    to the grid excludes the button frame and 4 px.
    """
    window_w = self.root.winfo_width()
    window_h = self.root.winfo_height()
    available_h = window_h - self.button_frame.winfo_height() - 4
    cell_w = self.image_size[0] + 4
    cell_h = self.image_size[1] + 4
    self.grid_cols = max(1, window_w // cell_w)
    self.grid_rows = max(1, available_h // cell_h)
2488
-
2489
def prefilter_paths_annotations(self):
    """Populate ``self.filtered_paths_annotations`` with ``[png_path, annotation]`` rows.

    Filtered path (``self.measurement`` truthy and ``self.threshold`` not None):

    - All tables are joined via ``_read_and_join_tables`` and merged with
      ``png_list`` on ``prcfo``; the annotation column is reset to ``None``.
    - Rows are filtered by the measurement/threshold rules below, rows without
      ``png_path`` are dropped, and ``image_type`` is applied as a substring
      filter (each entry in turn for a list).
    - A string threshold 'q1'..'q9' is replaced in ``self.threshold`` by the
      10%..90% quantile of the measurement.

    Otherwise ``png_path`` and the annotation column are read straight from
    ``png_list``, restricted to paths containing ``image_type`` (every entry,
    when it is a list).

    Fixes relative to the previous version:

    - ``self.measurement`` is no longer overwritten with the string
      ``'threshold_measurement'``. That broke re-filtering and showed
      comma-separated characters in the settings dialog.
    - The paired-ratio branch no longer raises IndexError/KeyError.
    - The SQLite connection is always closed.
    """
    from .io import _read_and_join_tables, _read_db
    from .utils import is_list_of_lists

    quantile_thresholds = {'q1': 0.1, 'q2': 0.2, 'q3': 0.3, 'q4': 0.4, 'q5': 0.5, 'q6': 0.6, 'q7': 0.7, 'q8': 0.8, 'q9': 0.9}

    if self.measurement and self.threshold is not None:
        df = _read_and_join_tables(self.db_path)
        png_list_df = _read_db(self.db_path, tables=['png_list'])[0]
        png_list_df = png_list_df.set_index('prcfo')
        df = df.merge(png_list_df, left_index=True, right_index=True)
        df[self.annotation_column] = None
        before = len(df)
        measurement = self.measurement

        if isinstance(self.threshold, int):
            # NOTE(review): this equality filter is followed by the threshold
            # filters below rather than replacing them; confirm that is intended.
            if isinstance(measurement, list):
                mes = measurement[0]
            if isinstance(measurement, str):
                mes = measurement
            df = df[df[f'{mes}'] == self.threshold]

        if is_list_of_lists(measurement):
            if isinstance(self.threshold, list) or is_list_of_lists(self.threshold):
                if len(measurement) == len(self.threshold):
                    for idx, var in enumerate(measurement):
                        df = df[df[var[idx]] > self.threshold[idx]]
                    after = len(df)
                elif len(measurement) == len(self.threshold) * 2:
                    # Consecutive (numerator, denominator) pairs, one threshold per pair.
                    # The old loop read measurement[idx + 1] for every idx (IndexError
                    # on the last one) and printed a column that did not exist.
                    for pair_idx, thd in enumerate(self.threshold):
                        numerator = measurement[2 * pair_idx]
                        denominator = measurement[2 * pair_idx + 1]
                        ratio_col = f'threshold_measurement_{pair_idx}'
                        df[ratio_col] = df[numerator] / df[denominator]
                        print(f"mean {ratio_col}: {np.mean(df[ratio_col])}")
                        print(f"median {ratio_col}: {np.median(df[ratio_col])}")
                        df = df[df[ratio_col] > thd]
                    after = len(df)

        elif isinstance(measurement, list):
            df['threshold_measurement'] = df[measurement[0]] / df[measurement[1]]
            print(f"mean threshold measurement: {np.mean(df['threshold_measurement'])}")
            print(f"median threshold measurement: {np.median(df['threshold_measurement'])}")
            df = df[df['threshold_measurement'] > self.threshold]
            after = len(df)
            print(f'Removed: {before-after} rows, retained {after}')

        else:
            print(f"mean threshold measurement: {np.mean(df[measurement])}")
            print(f"median threshold measurement: {np.median(df[measurement])}")
            before = len(df)
            if isinstance(self.threshold, str):
                quantile = quantile_thresholds.get(self.threshold)
                if quantile is not None:
                    self.threshold = df[measurement].quantile(quantile)
                print(f"threshold: {self.threshold}")

            df = df[df[measurement] > self.threshold]
            after = len(df)
            print(f'Removed: {before-after} rows, retained {after}')

        df = df.dropna(subset=['png_path'])
        if self.image_type:
            before = len(df)
            if isinstance(self.image_type, list):
                for tpe in self.image_type:
                    df = df[df['png_path'].str.contains(tpe)]
            else:
                df = df[df['png_path'].str.contains(self.image_type)]
            after = len(df)
            print(f'image_type: Removed: {before-after} rows, retained {after}')

        self.filtered_paths_annotations = df[['png_path', self.annotation_column]].values.tolist()
    else:
        query = f"SELECT png_path, {self.annotation_column} FROM png_list"
        params = []
        if self.image_type:
            # A list of image types requires every substring, like the pandas branch.
            image_types = self.image_type if isinstance(self.image_type, list) else [self.image_type]
            query += " WHERE " + " AND ".join("png_path LIKE ?" for _ in image_types)
            params = [f"%{tpe}%" for tpe in image_types]
        conn = sqlite3.connect(self.db_path)
        try:
            c = conn.cursor()
            c.execute(query, params)
            self.filtered_paths_annotations = c.fetchall()
        finally:
            conn.close()
2587
-
2588
def load_images(self):
    """Load and show the current page of thumbnails.

    Shows ``grid_rows * grid_cols`` entries starting at ``self.index``.

    - Paths that do not start with ``self.src`` but contain ``/data/`` are
      re-rooted under ``<src>/data/``; the original path is remembered in
      ``adjusted_to_original_paths``.
    - Images are loaded with ``load_single_image`` in a thread pool.
    - Annotated images get a 5 px border: ``active_color`` for annotation 1,
      red otherwise.
    - Each cell's left and right clicks are bound to ``get_on_image_click``
      for its image.
    """
    # Reset every cell, including its click bindings. Cells left empty on a short
    # (last) page used to keep the previous page's handlers, so clicking a blank
    # cell annotated an image that was no longer displayed.
    for label in self.labels:
        label.config(image='')
        label.unbind('<Button-1>')
        label.unbind('<Button-3>')

    self.images = {}
    page_size = self.grid_rows * self.grid_cols
    paths_annotations = self.filtered_paths_annotations[self.index:self.index + page_size]

    adjusted_paths = []
    for path, annotation in paths_annotations:
        if not path.startswith(self.src):
            parts = path.split('/data/')
            if len(parts) > 1:
                new_path = os.path.join(self.src, 'data', parts[1])
                self.adjusted_to_original_paths[new_path] = path
                path = new_path
        adjusted_paths.append((path, annotation))

    with ThreadPoolExecutor() as executor:
        loaded_images = list(executor.map(self.load_single_image, adjusted_paths))

    for i, (img, annotation) in enumerate(loaded_images):
        if annotation:
            border_color = self.active_color if annotation == 1 else 'red'
            img = self.add_colored_border(img, border_width=5, border_color=border_color)

        photo = ImageTk.PhotoImage(img)
        label = self.labels[i]
        self.images[label] = photo  # hold a reference to the PhotoImage while it is shown
        label.config(image=photo)

        path = adjusted_paths[i][0]
        label.bind('<Button-1>', self.get_on_image_click(path, label, img))
        label.bind('<Button-3>', self.get_on_image_click(path, label, img))

    self.root.update()
2626
-
2627
def load_single_image(self, path_annotation_tuple):
    """Open one image file and prepare its thumbnail.

    The image is normalised (per the app's normalisation settings), converted
    to RGB, channel-filtered, optionally outlined, and resized to
    ``self.image_size``. Returns ``(image, annotation)``.
    """
    path, annotation = path_annotation_tuple
    image = self.normalize_image(Image.open(path), self.normalize, self.percentiles, self.normalize_channels)
    image = self.filter_channels(image.convert('RGB'))
    if self.outline:
        image = self.outline_image(image, self.outline_sigma)
    return image.resize(self.image_size), annotation
2639
-
2640
def outline_image(self, img, edge_sigma=1, edge_thickness=1):
    """
    Replace each channel named in ``self.outline`` with its foreground outline.

    Each selected channel ('r', 'g' or 'b') is smoothed with a Gaussian of
    ``edge_sigma``. It is then thresholded at Otsu's value scaled by
    ``self.outline_threshold_factor`` (default 1.0, capped at 255). The inner
    boundary of the foreground, dilated by a disk of radius ``edge_thickness``,
    becomes the channel (255 on the edge, 0 elsewhere). A channel whose
    thresholding raises is left untouched. Non-RGB input is returned as-is.
    """
    pixels = np.asarray(img)
    is_rgb = pixels.ndim == 3 and pixels.shape[2] == 3
    if not is_rgb:
        return img  # not RGB

    result = pixels.copy()
    index_of = {'r': 0, 'g': 1, 'b': 2}
    scale = getattr(self, 'outline_threshold_factor', 1.0)
    selected = [index_of[name] for name in self.outline if name in index_of]

    for channel_idx in selected:
        try:
            smoothed = gaussian_filter(pixels[:, :, channel_idx], sigma=edge_sigma)
            cutoff = min(255, threshold_otsu(smoothed) * scale)
            foreground = smoothed > cutoff
        except Exception:
            # Best effort: keep the original channel if it cannot be thresholded.
            continue

        border = dilation(find_boundaries(foreground, mode='inner'), disk(edge_thickness))
        result[:, :, channel_idx] = (border * 255).astype(np.uint8)

    return Image.fromarray(result)
2673
-
2674
@staticmethod
def normalize_image(img, normalize=False, percentiles=(1, 99), normalize_channels=None):
    """
    Optionally percentile-stretch an image and return it as 8-bit.

    Args:
        img (PIL.Image or np.array): Input image.
        normalize (bool): When True, stretch intensities to 0-255 between the
            given percentiles.
        percentiles (tuple): Lower/upper percentiles used as the input range.
        normalize_channels (list): For multi-channel input, which of 'r', 'g',
            'b' to stretch (all three when None). Ignored for 2-D input.

    Returns:
        PIL.Image: Image built from the array after clipping to 0-255 and
        casting to uint8.
    """
    data = np.array(img)

    if normalize and data.ndim == 2:  # Grayscale image
        lo, hi = np.percentile(data, percentiles)
        data = rescale_intensity(data, in_range=(lo, hi), out_range=(0, 255))
    elif normalize:  # Color image or multi-channel image
        index_of = {'r': 0, 'g': 1, 'b': 2}
        wanted = ['r', 'g', 'b'] if normalize_channels is None else normalize_channels
        for name in wanted:
            if name not in index_of:
                continue
            idx = index_of[name]
            lo, hi = np.percentile(data[:, :, idx], percentiles)
            # Written back in place, so the array keeps its original dtype.
            data[:, :, idx] = rescale_intensity(data[:, :, idx], in_range=(lo, hi), out_range=(0, 255))

    data = np.clip(data, 0, 255).astype('uint8')
    return Image.fromarray(data)
2711
-
2712
-
2713
def add_colored_border(self, img, border_width, border_color):
    """
    Return a copy of *img* framed by four ``border_color`` strips.

    The canvas is ``border_width`` larger on every side and pre-filled with
    ``self.fg_color``. The strips only span the image's own width/height, so the
    four corner squares keep the ``fg_color`` background.
    """
    w, h = img.width, img.height
    horizontal = Image.new('RGB', (w, border_width), color=border_color)
    vertical = Image.new('RGB', (border_width, h), color=border_color)

    canvas = Image.new('RGB', (w + 2 * border_width, h + 2 * border_width), color=self.fg_color)
    placements = (
        (horizontal, (border_width, 0)),                # top
        (horizontal, (border_width, h + border_width)),  # bottom
        (vertical, (0, border_width)),                   # left
        (vertical, (w + border_width, border_width)),    # right
    )
    for strip, offset in placements:
        canvas.paste(strip, offset)
    canvas.paste(img, (border_width, border_width))

    return canvas
2727
-
2728
def filter_channels(self, img):
    """
    Zero out the RGB bands not listed in ``self.channels``.

    When ``self.channels`` is empty/None the image is re-merged unchanged. When
    exactly one channel is selected, that band alone is returned as a grayscale
    image. Otherwise an RGB image with the unselected bands zeroed is returned.
    """
    r, g, b = img.split()
    bands = {'r': r, 'g': g, 'b': b}
    wanted = self.channels

    if wanted:
        for name in ('r', 'g', 'b'):
            if name not in wanted:
                bands[name] = bands[name].point(lambda _: 0)

        if len(wanted) == 1:
            pick = 'r' if 'r' in wanted else ('g' if 'g' in wanted else 'b')
            return ImageOps.grayscale(bands[pick])

    return Image.merge("RGB", (bands['r'], bands['g'], bands['b']))
2743
-
2744
def get_on_image_click(self, path, label, img):
    """
    Build the click handler for one grid tile.

    Left click (``event.num == 1``) marks the image as 1 and right click
    (``event.num == 3``) as 2. Clicking again with the same button while that
    value is still pending resets it to None. The choice is stored in
    ``self.pending_updates`` under the original database path, and the tile is
    redrawn with a matching 5 px border (``self.active_color`` / red / none).

    Args:
        path: Displayed (possibly re-rooted) image path.
        label: Tk label showing the tile.
        img: The image as drawn by ``load_images``. It is bordered (and hence
            larger than ``self.image_size``) only when it was already annotated.
    """
    border_width = 5
    # load_images adds a border only to already-annotated tiles. Strip it only in
    # that case: cropping an unbordered tile would cut away real image content,
    # and un-annotating would then leave a shrunken tile.
    if img.size != tuple(self.image_size):
        base_img = img.crop((border_width, border_width, img.width - border_width, img.height - border_width))
    else:
        base_img = img

    def on_image_click(event):
        new_annotation = 1 if event.num == 1 else (2 if event.num == 3 else None)

        original_path = self.adjusted_to_original_paths.get(path, path)

        if original_path in self.pending_updates and self.pending_updates[original_path] == new_annotation:
            # Same button on a pending value toggles it off.
            self.pending_updates[original_path] = None
            new_annotation = None
        else:
            self.pending_updates[original_path] = new_annotation

        print(f"Image {os.path.split(path)[1]} annotated: {new_annotation}")

        border_fill = self.active_color if new_annotation == 1 else ('red' if new_annotation == 2 else None)
        img_ = ImageOps.expand(base_img, border=border_width, fill=border_fill) if border_fill else base_img

        photo = ImageTk.PhotoImage(img_)
        self.images[label] = photo
        label.config(image=photo)
        self.root.update()

    return on_image_click
2768
-
2769
@staticmethod
def update_html(text):
    """
    Set the innerHTML of the notebook element with id ``unique_id`` to *text*.

    The text is emitted as a JSON string literal, so quotes, backslashes and
    newlines cannot break the generated script. ``</`` is escaped so a literal
    ``</script>`` inside *text* cannot terminate the script element early.
    """
    import json

    js_text = json.dumps(str(text)).replace('</', '<\\/')
    display(HTML(f"""
<script>
document.getElementById('unique_id').innerHTML = {js_text};
</script>
"""))
2776
-
2777
def update_database_worker(self):
    """
    Poll ``self.update_queue`` and persist annotation batches to ``self.db_path``.

    Each queued item is a ``{png_path: annotation}`` dict. ``None`` clears the
    value in ``png_list.<self.annotation_column>``; anything else is stored.
    Each batch is committed as a whole. The queue is polled every 0.1 s.

    When ``self.terminate`` is set, every batch still queued is written before
    the loop exits. ``shutdown`` queues the last unsaved clicks just before
    setting the flag, and those must not be dropped. The final flush does not
    touch Tk, because the main thread may be blocked joining this worker.
    The connection is always closed, even if a write fails.

    NOTE(review): this presumably runs on ``self.db_update_thread`` (shutdown
    joins it). Tk calls from a non-main thread are not guaranteed to be safe;
    consider marshalling status updates through ``root.after``.
    """
    conn = sqlite3.connect(self.db_path)
    c = conn.cursor()

    def write_batch(pending_updates):
        # Column name comes from app configuration; values are bound parameters.
        for path, new_annotation in pending_updates.items():
            if new_annotation is None:
                c.execute(f'UPDATE png_list SET {self.annotation_column} = NULL WHERE png_path = ?', (path,))
            else:
                c.execute(f'UPDATE png_list SET {self.annotation_column} = ? WHERE png_path = ?', (new_annotation, path))
        conn.commit()

    try:
        display(HTML("<div id='unique_id'>Initial Text</div>"))

        while True:
            if self.terminate:
                # Flush whatever shutdown() (or earlier paging) queued, then stop.
                while not self.update_queue.empty():
                    write_batch(self.update_queue.get())
                break

            if not self.update_queue.empty():
                AnnotateApp.update_html("Do not exit, Updating database...")
                self.status_label.config(text='Do not exit, Updating database...')

                write_batch(self.update_queue.get())

                AnnotateApp.update_html('')
                self.status_label.config(text='')
                self.root.update()
            time.sleep(0.1)
    finally:
        conn.close()
2804
-
2805
def update_gui_text(self, text):
    """Show *text* in the status label and let Tk process the redraw right away."""
    status = self.status_label
    status.config(text=text)
    self.root.update()
2808
-
2809
def next_page(self):
    """
    Move the grid forward by one page (``grid_rows * grid_cols`` images).

    A snapshot of the unsaved clicks is handed to the database worker queue and
    the local dict is emptied. Annotations are then re-queried and the new page
    is drawn.
    """
    pending = self.pending_updates
    if pending:
        self.update_queue.put(pending.copy())
        pending.clear()

    step = self.grid_rows * self.grid_cols
    self.index = self.index + step
    self.prefilter_paths_annotations()  # Re-fetch annotations from the database
    self.load_images()
2816
-
2817
def previous_page(self):
    """
    Move the grid back by one page, never going below index 0.

    Unsaved clicks are queued for the database worker first. Annotations are
    then re-queried and the page is redrawn.
    """
    pending = self.pending_updates
    if pending:
        self.update_queue.put(pending.copy())
        pending.clear()

    step = self.grid_rows * self.grid_cols
    self.index = max(0, self.index - step)
    self.prefilter_paths_annotations()  # Re-fetch annotations from the database
    self.load_images()
2826
-
2827
def shutdown(self):
    """
    Persist outstanding annotations, stop the database worker and close the window.

    Unsaved clicks are queued before ``self.terminate`` is raised, so the
    worker's final flush is guaranteed to see them. The previous version set the
    flag first and, whenever there were pending updates, never cleared them or
    quit, which left the application impossible to close. The worker thread is
    then joined so the writes finish before Tk is torn down.
    """
    if self.pending_updates:
        self.update_queue.put(self.pending_updates.copy())
        self.pending_updates.clear()

    self.terminate = True
    self.db_update_thread.join()
    self.root.quit()
    self.root.destroy()
    print('Quit application')
2838
-
2839
def train_and_classify(self):
    """
    Train an XGBoost classifier on the manual annotations and score every image.

    1) Merge data from the relevant DB tables (including png_list).
    2) Collect manual annotations from png_list.<annotation_column> => 'manual_annotation'.
       - 1 => class=1, 2 => class=0 (for training).
    3) If only one class is present, randomly sample unannotated images as the other class.
    4) Train an XGBoost model on the numeric columns. The label column itself and
       earlier XGboost_* outputs are excluded: png_list is one of the merged
       tables, so those columns would otherwise leak the target into the features.
    5) Classify *all* rows -> XGboost_score (prob of class=1) and XGboost_annotation
       (1 if prob > 0.9, 2 if prob < 0.1, otherwise NULL).
    6) Write those columns back to png_list and COMMIT. Without the commit,
       sqlite3's implicit transaction is rolled back and the predictions are lost.
    7) Switch the displayed column to XGboost_annotation and refresh the UI
       (prefilter_paths_annotations + load_images).
    """
    self.update_gui_text("Merging data...")

    from .io import _read_and_merge_data

    # (1) Merge data; the per-object frames returned alongside are not used here.
    merged_df, _ = _read_and_merge_data(
        locs=[self.db_path],
        tables=['cell', 'cytoplasm', 'nucleus', 'pathogen', 'png_list'],
        verbose=False
    )

    # (2) Load manual annotations from the DB.
    conn = sqlite3.connect(self.db_path)
    try:
        c = conn.cursor()
        c.execute(f"SELECT png_path, {self.annotation_column} FROM png_list WHERE {self.annotation_column} IS NOT NULL")
        annotated_rows = c.fetchall()  # e.g. [(png_path, 1 or 2), ...]
    finally:
        conn.close()

    annot_dict = dict(annotated_rows)  # {png_path -> 1 or 2}
    merged_df['manual_annotation'] = merged_df['png_path'].map(annot_dict)

    annotated_df = merged_df.dropna(subset=['manual_annotation']).copy()
    # Convert "2" => "0" for binary classification.
    annotated_df['manual_annotation'] = annotated_df['manual_annotation'].replace({2: 0}).astype(int)

    # (3) Handle single-class scenario.
    class_counts = annotated_df['manual_annotation'].value_counts()
    if len(class_counts) == 1:
        single_class = class_counts.index[0]  # 0 or 1
        needed = class_counts.iloc[0]
        other_class = 1 if single_class == 0 else 0

        unannotated_df_all = merged_df[merged_df['manual_annotation'].isna()].copy()
        if len(unannotated_df_all) == 0:
            print("No unannotated rows to sample for the other class. Cannot proceed.")
            self.update_gui_text("Not enough data to train (no second class).")
            return

        sample_size = min(needed, len(unannotated_df_all))
        artificially_labeled = unannotated_df_all.sample(n=sample_size, replace=False).copy()
        artificially_labeled['manual_annotation'] = other_class

        annotated_df = pd.concat([annotated_df, artificially_labeled], ignore_index=True)
        print(f"Only one class was present => randomly labeled {sample_size} unannotated rows as {other_class}.")

    if len(annotated_df) < 2:
        print("Not enough annotated data to train (need at least 2).")
        self.update_gui_text("Not enough data to train.")
        return

    # (4) Train XGBoost.
    self.update_gui_text("Training XGBoost model...")

    # Never train on the label (or anything derived from it): the annotation column
    # and previous model outputs come along with png_list and are numeric.
    ignore_cols = {
        'png_path', 'manual_annotation', self.annotation_column,
        'XGboost_score', 'XGboost_annotation',
    }
    feature_cols = [
        col for col in annotated_df.columns
        if col not in ignore_cols
        and (annotated_df[col].dtype == float or annotated_df[col].dtype == int)
    ]
    if not feature_cols:
        print("No numeric feature columns available to train on.")
        self.update_gui_text("Not enough data to train (no numeric features).")
        return

    X_data = annotated_df[feature_cols].fillna(0).values
    y_data = annotated_df['manual_annotation'].values

    X_train, X_test, y_train, y_test = train_test_split(
        X_data, y_data, test_size=0.1, random_state=42
    )
    model = XGBClassifier(use_label_encoder=False, eval_metric='logloss')
    model.fit(X_train, y_train)

    preds = model.predict(X_test)
    print("=== Classification Report ===")
    print(classification_report(y_test, preds))
    print("=== Confusion Matrix ===")
    print(confusion_matrix(y_test, preds))

    # (5) Classify ALL rows.
    probs_all = model.predict_proba(merged_df[feature_cols].fillna(0).values)[:, 1]

    def get_annotation_from_prob(prob):
        # Only confident calls get an annotation; the DB uses 2 for the negative class.
        if prob > 0.9:
            return 1
        if prob < 0.1:
            return 2
        return None  # uncertain -> stored as NULL

    # (6) Write results back to png_list.
    self.update_gui_text("Updating the database with XGBoost predictions...")
    rows = [
        (get_annotation_from_prob(prob), float(prob), the_path)
        for the_path, prob in zip(merged_df['png_path'], probs_all)
        if not pd.isna(the_path)  # skip if no path
    ]

    conn = sqlite3.connect(self.db_path)
    try:
        c = conn.cursor()
        for ddl in (
            "ALTER TABLE png_list ADD COLUMN XGboost_annotation INTEGER",
            "ALTER TABLE png_list ADD COLUMN XGboost_score FLOAT",
        ):
            try:
                c.execute(ddl)
            except sqlite3.OperationalError:
                pass  # column already exists
        # A None annotation binds as NULL, so one statement covers both cases.
        c.executemany(
            "UPDATE png_list SET XGboost_annotation = ?, XGboost_score = ? WHERE png_path = ?",
            rows,
        )
        conn.commit()
    finally:
        conn.close()

    # (7) Show the model's calls in the grid.
    self.annotation_column = 'XGboost_annotation'
    self.prefilter_paths_annotations()
    self.load_images()
    self.update_gui_text('')
2995
-
2996
1322
  def standardize_figure(fig):
2997
1323
  from .gui_elements import set_dark_style
2998
1324
  from matplotlib.font_manager import FontProperties