spacr 0.3.52__py3-none-any.whl → 0.3.55__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries; it is provided for informational purposes only.
- spacr/gui_elements.py +1 -1
- spacr/gui_utils.py +0 -111
- spacr/io.py +114 -140
- spacr/measure.py +10 -11
- spacr/ml.py +41 -32
- spacr/plot.py +24 -293
- spacr/sequencing.py +13 -9
- spacr/settings.py +15 -9
- spacr/submodules.py +19 -19
- spacr/timelapse.py +16 -16
- spacr/toxo.py +15 -15
- spacr/utils.py +72 -164
- {spacr-0.3.52.dist-info → spacr-0.3.55.dist-info}/METADATA +1 -1
- {spacr-0.3.52.dist-info → spacr-0.3.55.dist-info}/RECORD +18 -18
- {spacr-0.3.52.dist-info → spacr-0.3.55.dist-info}/LICENSE +0 -0
- {spacr-0.3.52.dist-info → spacr-0.3.55.dist-info}/WHEEL +0 -0
- {spacr-0.3.52.dist-info → spacr-0.3.55.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.52.dist-info → spacr-0.3.55.dist-info}/top_level.txt +0 -0
spacr/gui_elements.py
CHANGED
@@ -706,7 +706,7 @@ class spacrProgressBar(ttk.Progressbar):
 
     def set_label_position(self):
         if self.label and self.progress_label:
-            row_info = self.grid_info().get('row', 0)
+            row_info = self.grid_info().get('row_name', 0)
             col_info = self.grid_info().get('column', 0)
             col_span = self.grid_info().get('columnspan', 1)
             self.progress_label.grid(row=row_info + 1, column=col_info, columnspan=col_span, pady=5, padx=5, sticky='ew')
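A detail worth flagging on this hunk: Tkinter's `grid_info()` dict exposes the standard grid options ('row', 'column', 'columnspan', ...) and has no 'row_name' key, so the new lookup always falls back to its default of 0. A minimal check with a plain ttk widget, nothing spacr-specific:

```python
import tkinter as tk
from tkinter import ttk

root = tk.Tk()
bar = ttk.Progressbar(root)
bar.grid(row=3, column=0)

info = bar.grid_info()
print(info.get('row', 0))       # 3: 'row' is a standard grid_info() key
print(info.get('row_name', 0))  # 0: 'row_name' is not a grid option, so the default is used
```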
spacr/gui_utils.py
CHANGED
@@ -106,32 +106,6 @@ def parse_list(value):
     except (ValueError, SyntaxError) as e:
         raise ValueError(f"Invalid format for list: {value}. Error: {e}")
 
-def parse_list_v1(value):
-    """
-    Parses a string representation of a list and returns the parsed list.
-
-    Args:
-        value (str): The string representation of the list.
-
-    Returns:
-        list: The parsed list, which can contain integers, floats, or strings.
-
-    Raises:
-        ValueError: If the input value is not a valid list format or contains mixed types or unsupported types.
-    """
-    try:
-        parsed_value = ast.literal_eval(value)
-        if isinstance(parsed_value, list):
-            # Check if all elements are homogeneous (either all int, float, or str)
-            if all(isinstance(item, (int, float, str)) for item in parsed_value):
-                return parsed_value
-            else:
-                raise ValueError("List contains mixed types or unsupported types")
-        else:
-            raise ValueError(f"Expected a list but got {type(parsed_value).__name__}")
-    except (ValueError, SyntaxError) as e:
-        raise ValueError(f"Invalid format for list: {value}. Error: {e}")
-
 # Usage example in your create_input_field function
 def create_input_field(frame, label_text, row, var_type='entry', options=None, default_value=None):
     """
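The surviving `parse_list` keeps the same `ast.literal_eval`-based parsing whose error handling is visible in the context lines above. For reference, that behavior in isolation:

```python
import ast

print(ast.literal_eval("[1, 2, 3]"))   # [1, 2, 3]
print(ast.literal_eval("['a', 'b']"))  # ['a', 'b']

try:
    ast.literal_eval("[1, 2,")         # malformed list literal
except (ValueError, SyntaxError) as e:
    # parse_list re-raises this as ValueError(f"Invalid format for list: ...")
    print(type(e).__name__)            # SyntaxError
```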
@@ -696,91 +670,6 @@ def ensure_after_tasks(frame):
     if not hasattr(frame, 'after_tasks'):
         frame.after_tasks = []
 
-def display_gif_in_plot_frame_v1(gif_path, parent_frame):
-    """Display and zoom a GIF to fill the entire parent_frame, maintaining aspect ratio, with lazy resizing and caching."""
-    # Clear parent_frame if it contains any previous widgets
-    for widget in parent_frame.winfo_children():
-        widget.destroy()
-
-    # Load the GIF
-    gif = Image.open(gif_path)
-
-    # Get the aspect ratio of the GIF
-    gif_width, gif_height = gif.size
-    gif_aspect_ratio = gif_width / gif_height
-
-    # Create a label to display the GIF and configure it to fill the parent_frame
-    label = tk.Label(parent_frame, bg="black")
-    label.grid(row=0, column=0, sticky="nsew")  # Expands in all directions (north, south, east, west)
-
-    # Configure parent_frame to stretch the label to fill available space
-    parent_frame.grid_rowconfigure(0, weight=1)
-    parent_frame.grid_columnconfigure(0, weight=1)
-
-    # Cache for storing resized frames (lazily filled)
-    resized_frames_cache = {}
-
-    # Last frame dimensions
-    last_frame_width = 0
-    last_frame_height = 0
-
-    def resize_and_crop_frame(frame_idx, frame_width, frame_height):
-        """Resize and crop the current frame of the GIF to fit the parent_frame while maintaining the aspect ratio."""
-        # If the frame is already cached at the current size, return it
-        if (frame_idx, frame_width, frame_height) in resized_frames_cache:
-            return resized_frames_cache[(frame_idx, frame_width, frame_height)]
-
-        # Calculate the scaling factor to zoom in on the GIF
-        scale_factor = max(frame_width / gif_width, frame_height / gif_height)
-
-        # Calculate new dimensions while maintaining the aspect ratio
-        new_width = int(gif_width * scale_factor)
-        new_height = int(gif_height * scale_factor)
-
-        # Resize the GIF to fit the frame
-        gif.seek(frame_idx)
-        resized_gif = gif.copy().resize((new_width, new_height), Image.Resampling.LANCZOS)
-
-        # Calculate the cropping box to center the resized GIF in the frame
-        crop_left = (new_width - frame_width) // 2
-        crop_top = (new_height - frame_height) // 2
-        crop_right = crop_left + frame_width
-        crop_bottom = crop_top + frame_height
-
-        # Crop the resized GIF to exactly fit the frame
-        cropped_gif = resized_gif.crop((crop_left, crop_top, crop_right, crop_bottom))
-
-        # Convert the cropped frame to a Tkinter-compatible format
-        frame_image = ImageTk.PhotoImage(cropped_gif)
-
-        # Cache the resized frame
-        resized_frames_cache[(frame_idx, frame_width, frame_height)] = frame_image
-
-        return frame_image
-
-    def update_frame(frame_idx):
-        """Update the GIF frame using lazy resizing and caching."""
-        # Get the current size of the parent_frame
-        frame_width = parent_frame.winfo_width()
-        frame_height = parent_frame.winfo_height()
-
-        # Only resize if the frame size has changed
-        nonlocal last_frame_width, last_frame_height
-        if frame_width != last_frame_width or frame_height != last_frame_height:
-            last_frame_width, last_frame_height = frame_width, frame_height
-
-        # Get the resized and cropped frame image
-        frame_image = resize_and_crop_frame(frame_idx, frame_width, frame_height)
-        label.config(image=frame_image)
-        label.image = frame_image  # Keep a reference to avoid garbage collection
-
-        # Move to the next frame, or loop back to the beginning
-        next_frame_idx = (frame_idx + 1) % gif.n_frames
-        parent_frame.after(gif.info['duration'], update_frame, next_frame_idx)
-
-    # Start the GIF animation from frame 0
-    update_frame(0)
-
 def display_gif_in_plot_frame(gif_path, parent_frame):
     """Display and zoom a GIF to fill the entire parent_frame, maintaining aspect ratio, with lazy resizing and caching."""
     # Clear parent_frame if it contains any previous widgets
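The retained `display_gif_in_plot_frame` (context lines above) uses the same cover-fit geometry as the deleted copy: scale by the larger of the width/height ratios, then center-crop the overflow. Worked through on example numbers:

```python
# Cover-fit math from resize_and_crop_frame, for a 400x200 GIF in a 300x300 frame.
gif_width, gif_height = 400, 200
frame_width, frame_height = 300, 300

scale_factor = max(frame_width / gif_width, frame_height / gif_height)  # max(0.75, 1.5) = 1.5
new_width = int(gif_width * scale_factor)    # 600
new_height = int(gif_height * scale_factor)  # 300

crop_left = (new_width - frame_width) // 2   # 150 px trimmed from each side
crop_top = (new_height - frame_height) // 2  # 0 px trimmed top/bottom
print(crop_left, crop_top)  # 150 0
```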
spacr/io.py
CHANGED
@@ -292,121 +292,6 @@ def _load_normalized_images_and_labels(image_files, label_files, channels=None,
 
     return normalized_images, labels, image_names, label_names, orig_dims
 
-def _load_normalized_images_and_labels_v1(image_files, label_files, channels=None, percentiles=None, invert=False, visualize=False, remove_background=False, background=0, Signal_to_noise=10, target_height=None, target_width=None):
-
-    from .plot import normalize_and_visualize, plot_resize
-    from .utils import invert_image, apply_mask
-    from skimage.transform import resize as resizescikit
-
-    if isinstance(percentiles, list):
-        if len(percentiles) !=2:
-            percentiles = None
-        if not percentiles[0] is int:
-            percentiles = None
-        if not percentiles[1] is int:
-            percentiles = None
-
-    signal_thresholds = background * Signal_to_noise
-    lower_percentile = 2
-
-    images = []
-    labels = []
-    orig_dims = []
-
-    num_channels = 4
-    percentiles_1 = [[] for _ in range(num_channels)]
-    percentiles_99 = [[] for _ in range(num_channels)]
-
-    image_names = [os.path.basename(f) for f in image_files]
-    image_dir = os.path.dirname(image_files[0])
-
-    if label_files is not None:
-        label_names = [os.path.basename(f) for f in label_files]
-        label_dir = os.path.dirname(label_files[0])
-
-    # Load, normalize, and resize images
-    for i, img_file in enumerate(image_files):
-        image = cellpose.io.imread(img_file)
-        orig_dims.append((image.shape[0], image.shape[1]))
-        if invert:
-            image = invert_image(image)
-
-        # If specific channels are specified, select them
-        if channels is not None and image.ndim == 3:
-            image = image[..., channels]
-
-        if remove_background:
-            image[image < background] = 0
-
-        if image.ndim < 3:
-            image = np.expand_dims(image, axis=-1)
-
-        if percentiles is None:
-            for c in range(image.shape[-1]):
-                p1 = np.percentile(image[..., c], lower_percentile)
-                percentiles_1[c].append(p1)
-                for percentile in [98, 99, 99.9, 99.99, 99.999]:
-                    p = np.percentile(image[..., c], percentile)
-                    if p > signal_thresholds:
-                        percentiles_99[c].append(p)
-                        break
-
-        # Resize image
-        if target_height is not None and target_width is not None:
-            if image.ndim == 2:
-                image_shape = (target_height, target_width)
-            elif image.ndim == 3:
-                image_shape = (target_height, target_width, image.shape[-1])
-
-            image = resizescikit(image, image_shape, preserve_range=True, anti_aliasing=True).astype(image.dtype)
-
-        images.append(image)
-
-    if percentiles is None:
-        # Calculate average percentiles for normalization
-        avg_p1 = [np.mean(p) for p in percentiles_1]
-        avg_p99 = [np.mean(p) if len(p) > 0 else np.mean(percentiles_1[i]) for i, p in enumerate(percentiles_99)]
-
-        print(f'Average 1st percentiles: {avg_p1}, Average 99th percentiles: {avg_p99}')
-
-        normalized_images = []
-        for image in images:
-            normalized_image = np.zeros_like(image, dtype=np.float32)
-            for c in range(image.shape[-1]):
-                normalized_image[..., c] = rescale_intensity(image[..., c], in_range=(avg_p1[c], avg_p99[c]), out_range=(0, 1))
-            normalized_images.append(normalized_image)
-            if visualize:
-                normalize_and_visualize(image, normalized_image, title=f"Channel {c+1} Normalized")
-    else:
-        normalized_images = []
-        for image in images:
-            normalized_image = np.zeros_like(image, dtype=np.float32)
-            for c in range(image.shape[-1]):
-                low_p = np.percentile(image[..., c], percentiles[0])
-                high_p = np.percentile(image[..., c], percentiles[1])
-                normalized_image[..., c] = rescale_intensity(image[..., c], in_range=(low_p, high_p), out_range=(0, 1))
-            normalized_images.append(normalized_image)
-            if visualize:
-                normalize_and_visualize(image, normalized_image, title=f"Channel {c+1} Normalized")
-
-    if label_files is not None:
-        for lbl_file in label_files:
-            label = cellpose.io.imread(lbl_file)
-            # Resize label
-            if target_height is not None and target_width is not None:
-                label = resizescikit(label, (target_height, target_width), order=0, preserve_range=True, anti_aliasing=False).astype(label.dtype)
-            labels.append(label)
-    else:
-        label_names = []
-        label_dir = None
-
-    print(f'Loaded and normalized {len(normalized_images)} images and {len(labels)} labels from {image_dir} and {label_dir}')
-
-    if visualize and images and labels:
-        plot_resize(images, normalized_images, labels, labels)
-
-    return normalized_images, labels, image_names, label_names, orig_dims
-
 class CombineLoaders:
 
     """
@@ -1875,6 +1760,9 @@ def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus',
     Returns:
         pandas.DataFrame: The joined DataFrame containing the data from the specified tables, or None if an error occurs.
     """
+    from .utils import rename_columns_in_db
+    rename_columns_in_db(db_path)
+
     conn = sqlite3.connect(db_path)
     dataframes = {}
     for table_name in table_names:
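`rename_columns_in_db` is not shown in this diff; judging by the surrounding changes it presumably migrates databases written by older spacr versions to the new `row_name`/`column_name` column names. A hypothetical sketch of such a helper (names and behavior assumed, and SQLite >= 3.25 is required for RENAME COLUMN):

```python
import sqlite3

def rename_columns_in_db(db_path, renames=None):
    """Sketch only: rename legacy 'row'/'col' columns wherever they appear."""
    renames = renames or {'row': 'row_name', 'col': 'column_name'}
    with sqlite3.connect(db_path) as conn:
        cur = conn.cursor()
        tables = [t[0] for t in cur.execute(
            "SELECT name FROM sqlite_master WHERE type='table'")]
        for table in tables:
            cols = [c[1] for c in cur.execute(f'PRAGMA table_info("{table}")')]
            for old, new in renames.items():
                if old in cols and new not in cols:
                    cur.execute(f'ALTER TABLE "{table}" RENAME COLUMN "{old}" TO "{new}"')
```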
@@ -1885,11 +1773,11 @@ def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus',
             print(e)
     conn.close()
     if 'png_list' in dataframes:
-        png_list_df = dataframes['png_list'][['cell_id', 'png_path', 'plate', 'row', 'col']].copy()
+        png_list_df = dataframes['png_list'][['cell_id', 'png_path', 'plate', 'row_name', 'column_name']].copy()
         png_list_df['cell_id'] = png_list_df['cell_id'].str[1:].astype(int)
         png_list_df.rename(columns={'cell_id': 'object_label'}, inplace=True)
         if 'cell' in dataframes:
-            join_cols = ['object_label', 'plate', 'row', 'col']
+            join_cols = ['object_label', 'plate', 'row_name', 'column_name']
             dataframes['cell'] = pd.merge(dataframes['cell'], png_list_df, on=join_cols, how='left')
         else:
             print("Cell table not found in database tables.")
@@ -2190,6 +2078,8 @@ def _read_db(db_loc, tables):
     Returns:
         - dfs (list): A list of pandas DataFrames, each containing the data from a table.
     """
+    from .utils import rename_columns_in_db
+    rename_columns_in_db(db_loc)
     conn = sqlite3.connect(db_loc)
     dfs = []
     for table in tables:
@@ -2310,7 +2200,7 @@ def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=False, pathog
     merged_df = merged_df.merge(pathogens_g_df, left_index=True, right_index=True)
 
     #Add prc column (plate row column)
-    metadata = metadata.assign(prc = lambda x: x['plate'] + '_' + x['row'] + '_' +x['col'])
+    metadata = metadata.assign(prc = lambda x: x['plate'] + '_' + x['row_name'] + '_' +x['column_name'])
 
     #Count cells per well
     cells_well = pd.DataFrame(metadata.groupby('prc')['object_label'].nunique())
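The `prc` (plate, row, column) and `prcfo` (plate, row, column, field, object) identifiers are plain underscore joins of the renamed metadata columns; illustrated with made-up values:

```python
import pandas as pd

metadata = pd.DataFrame({'plate': ['plate1'], 'row_name': ['r3'], 'column_name': ['c5'],
                         'field': ['f2'], 'object_label': ['o7']})

metadata = metadata.assign(prc=lambda x: x['plate'] + '_' + x['row_name'] + '_' + x['column_name'])
metadata = metadata.assign(prcfo=lambda x: x['prc'] + '_' + x['field'] + '_' + x['object_label'])
print(metadata.loc[0, 'prc'])    # plate1_r3_c5
print(metadata.loc[0, 'prcfo'])  # plate1_r3_c5_f2_o7
```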
@@ -2322,7 +2212,7 @@ def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=False, pathog
     metadata.drop(columns=object_label_cols, inplace=True)
 
     #Add prcfo column (plate row column field object)
-    metadata = metadata.assign(prcfo = lambda x: x['plate'] + '_' + x['row'] + '_' +x['col']+ '_' +x['field']+ '_' +x['object_label'])
+    metadata = metadata.assign(prcfo = lambda x: x['plate'] + '_' + x['row_name'] + '_' +x['column_name']+ '_' +x['field']+ '_' +x['object_label'])
     metadata.set_index('prcfo', inplace=True)
 
     merged_df = metadata.merge(merged_df, left_index=True, right_index=True)
@@ -2517,6 +2407,10 @@ def _copy_missclassified(df):
     return
 
 def _read_db(db_loc, tables):
+
+    from .utils import rename_columns_in_db
+
+    rename_columns_in_db(db_loc)
     conn = sqlite3.connect(db_loc) # Create a connection to the database
     dfs = []
     for table in tables:
@@ -2667,7 +2561,7 @@ def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=False, pathog
     merged_df = merged_df.merge(pathogens_g_df, left_index=True, right_index=True)
 
     #Add prc column (plate row column)
-    metadata = metadata.assign(prc = lambda x: x['plate'] + '_' + x['row'] + '_' +x['col'])
+    metadata = metadata.assign(prc = lambda x: x['plate'] + '_' + x['row_name'] + '_' +x['column_name'])
 
     #Count cells per well
     cells_well = pd.DataFrame(metadata.groupby('prc')['object_label'].nunique())
@@ -2679,7 +2573,7 @@ def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=False, pathog
     metadata.drop(columns=object_label_cols, inplace=True)
 
     #Add prcfo column (plate row column field object)
-    metadata = metadata.assign(prcfo = lambda x: x['plate'] + '_' + x['row'] + '_' +x['col']+ '_' +x['field']+ '_' +x['object_label'])
+    metadata = metadata.assign(prcfo = lambda x: x['plate'] + '_' + x['row_name'] + '_' +x['column_name']+ '_' +x['field']+ '_' +x['object_label'])
     metadata.set_index('prcfo', inplace=True)
 
     merged_df = metadata.merge(merged_df, left_index=True, right_index=True)
@@ -3030,8 +2924,7 @@ def generate_loaders(src, mode='train', image_size=224, batch_size=32, classes=[
 def generate_training_dataset(settings):
 
     # Function to filter png_list_df by prcfo present in df without merging
-    def filter_png_list(db_path, settings):
-        tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']
+    def filter_png_list(db_path, settings, tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']):
         df, _ = _read_and_merge_data(locs=[db_path],
                                      tables=tables,
                                      verbose=False,
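Folding the `tables` list into the parameter list leans on a Python detail worth keeping in mind: default values are evaluated once, so the same list object is shared by every call. That is harmless while the function only reads it, as `filter_png_list` does; it bites only if the default is mutated:

```python
def f(tables=['cell', 'nucleus', 'pathogen', 'cytoplasm']):
    return len(tables)          # read-only use of the shared default: safe

def g(tables=['cell']):
    tables.append('nucleus')    # mutates the shared default
    return len(tables)

print(f(), f())  # 4 4
print(g(), g())  # 2 3: the earlier append persisted into the next call
```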
@@ -3053,9 +2946,8 @@ def generate_training_dataset(settings):
         return size
 
     # Measurement-based selection logic
-    def measurement_based_selection(settings, db_path):
+    def measurement_based_selection(settings, db_path, tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']):
         class_paths_ls = []
-        tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']
         df, _ = _read_and_merge_data(locs=[db_path],
                                      tables=tables,
                                      verbose=False,
@@ -3068,7 +2960,7 @@ def generate_training_dataset(settings):
                                  treatment_loc=settings['class_metadata'])#, types=settings['metadata_type_by'])
         print('length df 2', len(df))
 
-        png_list_df = filter_png_list(db_path, settings)
+        png_list_df = filter_png_list(db_path, settings, tables=settings['tables'])
 
         if settings['custom_measurement']:
             if isinstance(settings['custom_measurement'], list):
@@ -3101,8 +2993,8 @@ def generate_training_dataset(settings):
     # Metadata-based selection logic
     def metadata_based_selection(db_path, settings):
         class_paths_ls = []
-        df = filter_png_list(db_path, settings)
-
+        df = filter_png_list(db_path, settings, tables=settings['tables'])
+
         df['metadata_based_class'] = pd.NA
         for i, class_ in enumerate(settings['classes']):
             ls = settings['class_metadata'][i]
@@ -3126,10 +3018,10 @@ def generate_training_dataset(settings):
     def annotation_based_selection(db_path, dst, settings):
         class_paths_ls = training_dataset_from_annotation(db_path, dst, settings['annotation_column'], annotated_classes=settings['annotated_classes'])
 
-        size = get_smallest_class_size(class_paths_ls, settings, 'annotation')
-        for i, class_paths in enumerate(class_paths_ls):
-            if len(class_paths) > size:
-                class_paths_ls[i] = random.sample(class_paths, size)
+        #size = get_smallest_class_size(class_paths_ls, settings, 'annotation')
+        #for i, class_paths in enumerate(class_paths_ls):
+        #    if len(class_paths) > size:
+        #        class_paths_ls[i] = random.sample(class_paths, size)
 
         return class_paths_ls
 
@@ -3137,6 +3029,13 @@ def generate_training_dataset(settings):
     from .utils import get_paths_from_db, annotate_conditions, save_settings
     from .settings import set_generate_training_dataset_defaults
 
+    if 'nucleus' not in settings['tables']:
+        settings['nuclei_limit'] = False
+
+    if 'pathogen' not in settings['tables']:
+        settings['pathogen_limit'] = 0
+        settings['uninfected'] = True
+
     # Set default settings and save
     settings = set_generate_training_dataset_defaults(settings)
     save_settings(settings, 'cv_dataset', show=True)
@@ -3145,6 +3044,7 @@ def generate_training_dataset(settings):
 
     if isinstance(settings['src'], str):
         src = [settings['src']]
+        settings['src'] = src
 
     for i, src in enumerate(settings['src']):
         db_path = os.path.join(src, 'measurements', 'measurements.db')
@@ -3170,7 +3070,7 @@ def generate_training_dataset(settings):
             class_paths_ls = metadata_based_selection(db_path, settings)
 
         elif settings['dataset_mode'] == 'measurement':
-            class_paths_ls = measurement_based_selection(settings, db_path)
+            class_paths_ls = measurement_based_selection(settings, db_path, tables=settings['tables'])
 
         if class_path_list is None:
             class_path_list = [[] for _ in range(len(class_paths_ls))]
@@ -3180,22 +3080,72 @@ def generate_training_dataset(settings):
             class_path_list[idx].extend(class_paths_ls[idx])
 
     # Generate and return training and testing directories
+    print('class_path_list',len(class_path_list))
    train_class_dir, test_class_dir = generate_dataset_from_lists(dst, class_data=class_path_list, classes=settings['classes'], test_split=settings['test_split'])
 
     return train_class_dir, test_class_dir
 
 def training_dataset_from_annotation(db_path, dst, annotation_column='test', annotated_classes=(1, 2)):
     all_paths = []
-
+
     # Connect to the database and retrieve the image paths and annotations
     print(f'Reading DataBase: {db_path}')
     with sqlite3.connect(db_path) as conn:
         cursor = conn.cursor()
-        #
-
-        query
-
+        # Retrieve all paths and annotations from the database
+        query = f"SELECT png_path, {annotation_column} FROM png_list"
+        cursor.execute(query)
+
+        while True:
+            rows = cursor.fetchmany(1000)
+            if not rows:
+                break
+            for row in rows:
+                all_paths.append(row)
+
+    print('Total paths retrieved:', len(all_paths))
+
+    # Filter paths based on annotated_classes
+    class_paths = []
+    for class_ in annotated_classes:
+        class_paths_temp = [path for path, annotation in all_paths if annotation == class_]
+        class_paths.append(class_paths_temp)
+        print(f'Found {len(class_paths_temp)} images in class {class_}')
+
+    # If only one class is provided, create an alternative list by sampling paths from all_paths that are not in the annotated class
+    if len(annotated_classes) == 1:
+        target_class = annotated_classes[0]
+        count_target_class = len(class_paths[0])
+        print(f'Annotated class: {target_class} with {count_target_class} images')
+
+        # Filter all_paths to exclude paths that belong to the target class
+        alt_class_paths = [path for path, annotation in all_paths if annotation != target_class]
+        print('Alternative paths available:', len(alt_class_paths))
+
+        # Randomly sample an equal number of images for the second class
+        sampled_alt_class_paths = random.sample(alt_class_paths, min(count_target_class, len(alt_class_paths)))
+        print(f'Sampled {len(sampled_alt_class_paths)} alternative images for balancing')
+
+        # Append this list as the second class
+        class_paths.append(sampled_alt_class_paths)
+
+    print(f'Generated a list of lists from annotation of {len(class_paths)} classes')
+    for i, ls in enumerate(class_paths):
+        print(f'Class {i}: {len(ls)} images')
+
+    return class_paths
 
+def training_dataset_from_annotation_v2(db_path, dst, annotation_column='test', annotated_classes=(1, 2)):
+    all_paths = []
+
+    # Connect to the database and retrieve the image paths and annotations
+    print(f'Reading DataBase: {db_path}')
+    with sqlite3.connect(db_path) as conn:
+        cursor = conn.cursor()
+        # Retrieve all paths and annotations from the database
+        query = f"SELECT png_path, {annotation_column} FROM png_list"
+        cursor.execute(query)
+
         while True:
             rows = cursor.fetchmany(1000)
             if not rows:
@@ -3203,13 +3153,36 @@ def training_dataset_from_annotation(db_path, dst, annotation_column='test', ann
             for row in rows:
                 all_paths.append(row)
 
-
+    print('Total paths retrieved:', len(all_paths))
+
+    # Filter paths based on annotated_classes
     class_paths = []
     for class_ in annotated_classes:
         class_paths_temp = [path for path, annotation in all_paths if annotation == class_]
         class_paths.append(class_paths_temp)
+        print(f'Found {len(class_paths_temp)} images in class {class_}')
+
+    # If only one class is provided, create an alternative list by sampling paths from all_paths that are not in the annotated class
+    if len(annotated_classes) == 1:
+        target_class = annotated_classes[0]
+        count_target_class = len(class_paths[0])
+        print(f'Annotated class: {target_class} with {count_target_class} images')
+
+        # Filter all_paths to exclude paths that belong to the target class
+        alt_class_paths = [path for path, annotation in all_paths if annotation != target_class]
+        print('Alternative paths available:', len(alt_class_paths))
+
+        # Randomly sample an equal number of images for the second class
+        sampled_alt_class_paths = random.sample(alt_class_paths, min(count_target_class, len(alt_class_paths)))
+        print(f'Sampled {len(sampled_alt_class_paths)} alternative images for balancing')
+
+        # Append this list as the second class
+        class_paths.append(sampled_alt_class_paths)
 
     print(f'Generated a list of lists from annotation of {len(class_paths)} classes')
+    for i, ls in enumerate(class_paths):
+        print(f'Class {i}: {len(ls)} images')
+
     return class_paths
 
 def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1):
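A hypothetical call of the rewritten function, assuming a measurements.db whose png_list table carries a 'test' annotation column; with a single annotated class, the new balancing branch pairs it with an equal-sized random sample of all other paths:

```python
class_paths = training_dataset_from_annotation(
    db_path='measurements/measurements.db',  # assumed location
    dst='cv_dataset',
    annotation_column='test',
    annotated_classes=(1,),
)
# class_paths[0]: every png_path annotated as class 1
# class_paths[1]: an equal-sized random sample of the remaining paths
```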
@@ -3228,8 +3201,9 @@ def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1):
         test_class_dir = os.path.join(dst, f'test/{cls}')
         os.makedirs(train_class_dir, exist_ok=True)
         os.makedirs(test_class_dir, exist_ok=True)
-
+
         # Split the data
+        print('data',len(data), test_split)
         train_data, test_data = train_test_split(data, test_size=test_split, shuffle=True, random_state=42)
 
         # Copy train files
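`train_test_split` here is presumably scikit-learn's; with a fixed `random_state` the split that the new debug print reports is reproducible across runs:

```python
from sklearn.model_selection import train_test_split

data = [f'img_{i}.png' for i in range(10)]
train_data, test_data = train_test_split(data, test_size=0.1, shuffle=True, random_state=42)
print(len(train_data), len(test_data))  # 9 1
```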
spacr/measure.py
CHANGED
@@ -16,6 +16,7 @@ from skimage.util import img_as_bool
 import matplotlib.pyplot as plt
 from math import ceil, sqrt
 
+
 def get_components(cell_mask, nucleus_mask, pathogen_mask):
     """
     Get the components (nucleus and pathogens) for each cell in the given masks.
@@ -761,12 +762,10 @@ def _measure_crop_core(index, time_ls, file, settings):
         if settings['cytoplasm_min_size'] is not None and settings['cytoplasm_min_size'] != 0:
             cytoplasm_mask = _filter_object(cytoplasm_mask, settings['cytoplasm_min_size'])
 
-        if settings['cell_mask_dim'] is not None:
+        if settings['cell_mask_dim'] is not None and settings['nucleus_mask_dim'] is not None and settings['pathogen_mask_dim'] is not None:
             cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask = _exclude_objects(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, uninfected=settings['uninfected'])
-
-        # Update data with the new masks
-        if settings['cell_mask_dim'] is not None:
             data[:, :, settings['cell_mask_dim']] = cell_mask.astype(data_type)
+
         if settings['nucleus_mask_dim'] is not None:
             data[:, :, settings['nucleus_mask_dim']] = nucleus_mask.astype(data_type)
         if settings['pathogen_mask_dim'] is not None:
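The tightened guard changes behavior for partially configured runs: `_exclude_objects` (and the cell-mask write-back now inside the same block) only executes when cell, nucleus, and pathogen mask dimensions are all set. In isolation, with illustrative settings values:

```python
# Sketch of the new gating condition, not spacr code.
settings = {'cell_mask_dim': 0, 'nucleus_mask_dim': 1, 'pathogen_mask_dim': None}

run_exclusion = (settings['cell_mask_dim'] is not None
                 and settings['nucleus_mask_dim'] is not None
                 and settings['pathogen_mask_dim'] is not None)
print(run_exclusion)  # False: a cell+nucleus run now skips _exclude_objects
```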
@@ -779,7 +778,6 @@ def _measure_crop_core(index, time_ls, file, settings):
             figs[f'{file_name}__after_filtration'] = fig
 
         if settings['save_measurements']:
-
             cell_df, nucleus_df, pathogen_df, cytoplasm_df = _morphological_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, settings)
 
             #if settings['skeleton']:
@@ -789,7 +787,6 @@ def _measure_crop_core(index, time_ls, file, settings):
             cell_intensity_df, nucleus_intensity_df, pathogen_intensity_df, cytoplasm_intensity_df = _intensity_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, channel_arrays, settings, sizes=[1, 2, 3, 4, 5], periphery=True, outside=True)
             if settings['cell_mask_dim'] is not None:
                 cell_merged_df = _merge_and_save_to_database(cell_df, cell_intensity_df, 'cell', source_folder, file_name, settings['experiment'], settings['timelapse'])
-
             if settings['nucleus_mask_dim'] is not None:
                 nucleus_merged_df = _merge_and_save_to_database(nucleus_df, nucleus_intensity_df, 'nucleus', source_folder, file_name, settings['experiment'], settings['timelapse'])
 
@@ -800,7 +797,6 @@ def _measure_crop_core(index, time_ls, file, settings):
                 cytoplasm_merged_df = _merge_and_save_to_database(cytoplasm_df, cytoplasm_intensity_df, 'cytoplasm', source_folder, file_name, settings['experiment'], settings['timelapse'])
 
         if settings['save_png'] or settings['save_arrays'] or settings['plot']:
-
             if isinstance(settings['dialate_pngs'], bool):
                 dialate_pngs = [settings['dialate_pngs'], settings['dialate_pngs'], settings['dialate_pngs']]
             if isinstance(settings['dialate_pngs'], list):
@@ -825,13 +821,15 @@ def _measure_crop_core(index, time_ls, file, settings):
 
             if len(crop_ls) != len(size_ls):
                 print(f"Setting: size_ls: {settings['png_size']} should be a list of integers, or a list of lists of integers if crop_ls: {settings['crop_mode']} has multiple elements")
-
+
             for crop_idx, crop_mode in enumerate(crop_ls):
                 width, height = size_ls[crop_idx]
+
                 if crop_mode == 'cell':
                     crop_mask = cell_mask.copy()
                     dialate_png = dialate_pngs[crop_idx]
                     dialate_png_ratio = dialate_png_ratios[crop_idx]
+
                 elif crop_mode == 'nucleus':
                     crop_mask = nucleus_mask.copy()
                     dialate_png = dialate_pngs[crop_idx]
@@ -852,7 +850,7 @@ def _measure_crop_core(index, time_ls, file, settings):
 
                 for _id in objects_in_image:
 
-                        region = (crop_mask == _id)
+                    region = (crop_mask == _id)
 
                     # Use the boolean mask to filter the cell_mask and then find unique IDs
                     region_cell_ids = np.atleast_1d(np.unique(cell_mask[region]))
@@ -947,7 +945,7 @@ def measure_crop(settings):
 
     from .io import _save_settings_to_db
     from .timelapse import _timelapse_masks_to_gif
-    from .utils import measure_test_mode, print_progress
+    from .utils import measure_test_mode, print_progress, save_settings
    from .settings import get_measure_crop_settings
 
     if not isinstance(settings['src'], (str, list)):
@@ -1032,9 +1030,10 @@ def measure_crop(settings):
             settings['crop_mode'] = [settings['crop_mode']]
             settings['crop_mode'] = [str(crop_mode) for crop_mode in settings['crop_mode']]
             print(f"Converted crop_mode to list: {settings['crop_mode']}")
-            return
 
     _save_settings_to_db(settings)
+    #save_settings(settings, name='measure_crop', show=True)
+
     files = [f for f in os.listdir(settings['src']) if f.endswith('.npy')]
     n_jobs = settings['n_jobs']
     print(f'using {n_jobs} cpu cores')