spacr 0.3.52__py3-none-any.whl → 0.3.55__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/gui_elements.py CHANGED
@@ -706,7 +706,7 @@ class spacrProgressBar(ttk.Progressbar):
706
706
 
707
707
  def set_label_position(self):
708
708
  if self.label and self.progress_label:
709
- row_info = self.grid_info().get('row', 0)
709
+ row_info = self.grid_info().get('row_name', 0)
710
710
  col_info = self.grid_info().get('column', 0)
711
711
  col_span = self.grid_info().get('columnspan', 1)
712
712
  self.progress_label.grid(row=row_info + 1, column=col_info, columnspan=col_span, pady=5, padx=5, sticky='ew')
spacr/gui_utils.py CHANGED
@@ -106,32 +106,6 @@ def parse_list(value):
106
106
  except (ValueError, SyntaxError) as e:
107
107
  raise ValueError(f"Invalid format for list: {value}. Error: {e}")
108
108
 
109
- def parse_list_v1(value):
110
- """
111
- Parses a string representation of a list and returns the parsed list.
112
-
113
- Args:
114
- value (str): The string representation of the list.
115
-
116
- Returns:
117
- list: The parsed list, which can contain integers, floats, or strings.
118
-
119
- Raises:
120
- ValueError: If the input value is not a valid list format or contains mixed types or unsupported types.
121
- """
122
- try:
123
- parsed_value = ast.literal_eval(value)
124
- if isinstance(parsed_value, list):
125
- # Check if all elements are homogeneous (either all int, float, or str)
126
- if all(isinstance(item, (int, float, str)) for item in parsed_value):
127
- return parsed_value
128
- else:
129
- raise ValueError("List contains mixed types or unsupported types")
130
- else:
131
- raise ValueError(f"Expected a list but got {type(parsed_value).__name__}")
132
- except (ValueError, SyntaxError) as e:
133
- raise ValueError(f"Invalid format for list: {value}. Error: {e}")
134
-
135
109
  # Usage example in your create_input_field function
136
110
  def create_input_field(frame, label_text, row, var_type='entry', options=None, default_value=None):
137
111
  """
@@ -696,91 +670,6 @@ def ensure_after_tasks(frame):
696
670
  if not hasattr(frame, 'after_tasks'):
697
671
  frame.after_tasks = []
698
672
 
699
- def display_gif_in_plot_frame_v1(gif_path, parent_frame):
700
- """Display and zoom a GIF to fill the entire parent_frame, maintaining aspect ratio, with lazy resizing and caching."""
701
- # Clear parent_frame if it contains any previous widgets
702
- for widget in parent_frame.winfo_children():
703
- widget.destroy()
704
-
705
- # Load the GIF
706
- gif = Image.open(gif_path)
707
-
708
- # Get the aspect ratio of the GIF
709
- gif_width, gif_height = gif.size
710
- gif_aspect_ratio = gif_width / gif_height
711
-
712
- # Create a label to display the GIF and configure it to fill the parent_frame
713
- label = tk.Label(parent_frame, bg="black")
714
- label.grid(row=0, column=0, sticky="nsew") # Expands in all directions (north, south, east, west)
715
-
716
- # Configure parent_frame to stretch the label to fill available space
717
- parent_frame.grid_rowconfigure(0, weight=1)
718
- parent_frame.grid_columnconfigure(0, weight=1)
719
-
720
- # Cache for storing resized frames (lazily filled)
721
- resized_frames_cache = {}
722
-
723
- # Last frame dimensions
724
- last_frame_width = 0
725
- last_frame_height = 0
726
-
727
- def resize_and_crop_frame(frame_idx, frame_width, frame_height):
728
- """Resize and crop the current frame of the GIF to fit the parent_frame while maintaining the aspect ratio."""
729
- # If the frame is already cached at the current size, return it
730
- if (frame_idx, frame_width, frame_height) in resized_frames_cache:
731
- return resized_frames_cache[(frame_idx, frame_width, frame_height)]
732
-
733
- # Calculate the scaling factor to zoom in on the GIF
734
- scale_factor = max(frame_width / gif_width, frame_height / gif_height)
735
-
736
- # Calculate new dimensions while maintaining the aspect ratio
737
- new_width = int(gif_width * scale_factor)
738
- new_height = int(gif_height * scale_factor)
739
-
740
- # Resize the GIF to fit the frame
741
- gif.seek(frame_idx)
742
- resized_gif = gif.copy().resize((new_width, new_height), Image.Resampling.LANCZOS)
743
-
744
- # Calculate the cropping box to center the resized GIF in the frame
745
- crop_left = (new_width - frame_width) // 2
746
- crop_top = (new_height - frame_height) // 2
747
- crop_right = crop_left + frame_width
748
- crop_bottom = crop_top + frame_height
749
-
750
- # Crop the resized GIF to exactly fit the frame
751
- cropped_gif = resized_gif.crop((crop_left, crop_top, crop_right, crop_bottom))
752
-
753
- # Convert the cropped frame to a Tkinter-compatible format
754
- frame_image = ImageTk.PhotoImage(cropped_gif)
755
-
756
- # Cache the resized frame
757
- resized_frames_cache[(frame_idx, frame_width, frame_height)] = frame_image
758
-
759
- return frame_image
760
-
761
- def update_frame(frame_idx):
762
- """Update the GIF frame using lazy resizing and caching."""
763
- # Get the current size of the parent_frame
764
- frame_width = parent_frame.winfo_width()
765
- frame_height = parent_frame.winfo_height()
766
-
767
- # Only resize if the frame size has changed
768
- nonlocal last_frame_width, last_frame_height
769
- if frame_width != last_frame_width or frame_height != last_frame_height:
770
- last_frame_width, last_frame_height = frame_width, frame_height
771
-
772
- # Get the resized and cropped frame image
773
- frame_image = resize_and_crop_frame(frame_idx, frame_width, frame_height)
774
- label.config(image=frame_image)
775
- label.image = frame_image # Keep a reference to avoid garbage collection
776
-
777
- # Move to the next frame, or loop back to the beginning
778
- next_frame_idx = (frame_idx + 1) % gif.n_frames
779
- parent_frame.after(gif.info['duration'], update_frame, next_frame_idx)
780
-
781
- # Start the GIF animation from frame 0
782
- update_frame(0)
783
-
784
673
  def display_gif_in_plot_frame(gif_path, parent_frame):
785
674
  """Display and zoom a GIF to fill the entire parent_frame, maintaining aspect ratio, with lazy resizing and caching."""
786
675
  # Clear parent_frame if it contains any previous widgets
spacr/io.py CHANGED
@@ -292,121 +292,6 @@ def _load_normalized_images_and_labels(image_files, label_files, channels=None,
292
292
 
293
293
  return normalized_images, labels, image_names, label_names, orig_dims
294
294
 
295
- def _load_normalized_images_and_labels_v1(image_files, label_files, channels=None, percentiles=None, invert=False, visualize=False, remove_background=False, background=0, Signal_to_noise=10, target_height=None, target_width=None):
296
-
297
- from .plot import normalize_and_visualize, plot_resize
298
- from .utils import invert_image, apply_mask
299
- from skimage.transform import resize as resizescikit
300
-
301
- if isinstance(percentiles, list):
302
- if len(percentiles) !=2:
303
- percentiles = None
304
- if not percentiles[0] is int:
305
- percentiles = None
306
- if not percentiles[1] is int:
307
- percentiles = None
308
-
309
- signal_thresholds = background * Signal_to_noise
310
- lower_percentile = 2
311
-
312
- images = []
313
- labels = []
314
- orig_dims = []
315
-
316
- num_channels = 4
317
- percentiles_1 = [[] for _ in range(num_channels)]
318
- percentiles_99 = [[] for _ in range(num_channels)]
319
-
320
- image_names = [os.path.basename(f) for f in image_files]
321
- image_dir = os.path.dirname(image_files[0])
322
-
323
- if label_files is not None:
324
- label_names = [os.path.basename(f) for f in label_files]
325
- label_dir = os.path.dirname(label_files[0])
326
-
327
- # Load, normalize, and resize images
328
- for i, img_file in enumerate(image_files):
329
- image = cellpose.io.imread(img_file)
330
- orig_dims.append((image.shape[0], image.shape[1]))
331
- if invert:
332
- image = invert_image(image)
333
-
334
- # If specific channels are specified, select them
335
- if channels is not None and image.ndim == 3:
336
- image = image[..., channels]
337
-
338
- if remove_background:
339
- image[image < background] = 0
340
-
341
- if image.ndim < 3:
342
- image = np.expand_dims(image, axis=-1)
343
-
344
- if percentiles is None:
345
- for c in range(image.shape[-1]):
346
- p1 = np.percentile(image[..., c], lower_percentile)
347
- percentiles_1[c].append(p1)
348
- for percentile in [98, 99, 99.9, 99.99, 99.999]:
349
- p = np.percentile(image[..., c], percentile)
350
- if p > signal_thresholds:
351
- percentiles_99[c].append(p)
352
- break
353
-
354
- # Resize image
355
- if target_height is not None and target_width is not None:
356
- if image.ndim == 2:
357
- image_shape = (target_height, target_width)
358
- elif image.ndim == 3:
359
- image_shape = (target_height, target_width, image.shape[-1])
360
-
361
- image = resizescikit(image, image_shape, preserve_range=True, anti_aliasing=True).astype(image.dtype)
362
-
363
- images.append(image)
364
-
365
- if percentiles is None:
366
- # Calculate average percentiles for normalization
367
- avg_p1 = [np.mean(p) for p in percentiles_1]
368
- avg_p99 = [np.mean(p) if len(p) > 0 else np.mean(percentiles_1[i]) for i, p in enumerate(percentiles_99)]
369
-
370
- print(f'Average 1st percentiles: {avg_p1}, Average 99th percentiles: {avg_p99}')
371
-
372
- normalized_images = []
373
- for image in images:
374
- normalized_image = np.zeros_like(image, dtype=np.float32)
375
- for c in range(image.shape[-1]):
376
- normalized_image[..., c] = rescale_intensity(image[..., c], in_range=(avg_p1[c], avg_p99[c]), out_range=(0, 1))
377
- normalized_images.append(normalized_image)
378
- if visualize:
379
- normalize_and_visualize(image, normalized_image, title=f"Channel {c+1} Normalized")
380
- else:
381
- normalized_images = []
382
- for image in images:
383
- normalized_image = np.zeros_like(image, dtype=np.float32)
384
- for c in range(image.shape[-1]):
385
- low_p = np.percentile(image[..., c], percentiles[0])
386
- high_p = np.percentile(image[..., c], percentiles[1])
387
- normalized_image[..., c] = rescale_intensity(image[..., c], in_range=(low_p, high_p), out_range=(0, 1))
388
- normalized_images.append(normalized_image)
389
- if visualize:
390
- normalize_and_visualize(image, normalized_image, title=f"Channel {c+1} Normalized")
391
-
392
- if label_files is not None:
393
- for lbl_file in label_files:
394
- label = cellpose.io.imread(lbl_file)
395
- # Resize label
396
- if target_height is not None and target_width is not None:
397
- label = resizescikit(label, (target_height, target_width), order=0, preserve_range=True, anti_aliasing=False).astype(label.dtype)
398
- labels.append(label)
399
- else:
400
- label_names = []
401
- label_dir = None
402
-
403
- print(f'Loaded and normalized {len(normalized_images)} images and {len(labels)} labels from {image_dir} and {label_dir}')
404
-
405
- if visualize and images and labels:
406
- plot_resize(images, normalized_images, labels, labels)
407
-
408
- return normalized_images, labels, image_names, label_names, orig_dims
409
-
410
295
  class CombineLoaders:
411
296
 
412
297
  """
@@ -1875,6 +1760,9 @@ def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus',
1875
1760
  Returns:
1876
1761
  pandas.DataFrame: The joined DataFrame containing the data from the specified tables, or None if an error occurs.
1877
1762
  """
1763
+ from .utils import rename_columns_in_db
1764
+ rename_columns_in_db(db_path)
1765
+
1878
1766
  conn = sqlite3.connect(db_path)
1879
1767
  dataframes = {}
1880
1768
  for table_name in table_names:
@@ -1885,11 +1773,11 @@ def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus',
1885
1773
  print(e)
1886
1774
  conn.close()
1887
1775
  if 'png_list' in dataframes:
1888
- png_list_df = dataframes['png_list'][['cell_id', 'png_path', 'plate', 'row', 'col']].copy()
1776
+ png_list_df = dataframes['png_list'][['cell_id', 'png_path', 'plate', 'row_name', 'column_name']].copy()
1889
1777
  png_list_df['cell_id'] = png_list_df['cell_id'].str[1:].astype(int)
1890
1778
  png_list_df.rename(columns={'cell_id': 'object_label'}, inplace=True)
1891
1779
  if 'cell' in dataframes:
1892
- join_cols = ['object_label', 'plate', 'row', 'col']
1780
+ join_cols = ['object_label', 'plate', 'row_name', 'column_name']
1893
1781
  dataframes['cell'] = pd.merge(dataframes['cell'], png_list_df, on=join_cols, how='left')
1894
1782
  else:
1895
1783
  print("Cell table not found in database tables.")
@@ -2190,6 +2078,8 @@ def _read_db(db_loc, tables):
2190
2078
  Returns:
2191
2079
  - dfs (list): A list of pandas DataFrames, each containing the data from a table.
2192
2080
  """
2081
+ from .utils import rename_columns_in_db
2082
+ rename_columns_in_db(db_loc)
2193
2083
  conn = sqlite3.connect(db_loc)
2194
2084
  dfs = []
2195
2085
  for table in tables:
@@ -2310,7 +2200,7 @@ def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=False, pathog
2310
2200
  merged_df = merged_df.merge(pathogens_g_df, left_index=True, right_index=True)
2311
2201
 
2312
2202
  #Add prc column (plate row column)
2313
- metadata = metadata.assign(prc = lambda x: x['plate'] + '_' + x['row'] + '_' +x['col'])
2203
+ metadata = metadata.assign(prc = lambda x: x['plate'] + '_' + x['row_name'] + '_' +x['column_name'])
2314
2204
 
2315
2205
  #Count cells per well
2316
2206
  cells_well = pd.DataFrame(metadata.groupby('prc')['object_label'].nunique())
@@ -2322,7 +2212,7 @@ def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=False, pathog
2322
2212
  metadata.drop(columns=object_label_cols, inplace=True)
2323
2213
 
2324
2214
  #Add prcfo column (plate row column field object)
2325
- metadata = metadata.assign(prcfo = lambda x: x['plate'] + '_' + x['row'] + '_' +x['col']+ '_' +x['field']+ '_' +x['object_label'])
2215
+ metadata = metadata.assign(prcfo = lambda x: x['plate'] + '_' + x['row_name'] + '_' +x['column_name']+ '_' +x['field']+ '_' +x['object_label'])
2326
2216
  metadata.set_index('prcfo', inplace=True)
2327
2217
 
2328
2218
  merged_df = metadata.merge(merged_df, left_index=True, right_index=True)
@@ -2517,6 +2407,10 @@ def _copy_missclassified(df):
2517
2407
  return
2518
2408
 
2519
2409
  def _read_db(db_loc, tables):
2410
+
2411
+ from .utils import rename_columns_in_db
2412
+
2413
+ rename_columns_in_db(db_loc)
2520
2414
  conn = sqlite3.connect(db_loc) # Create a connection to the database
2521
2415
  dfs = []
2522
2416
  for table in tables:
@@ -2667,7 +2561,7 @@ def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=False, pathog
2667
2561
  merged_df = merged_df.merge(pathogens_g_df, left_index=True, right_index=True)
2668
2562
 
2669
2563
  #Add prc column (plate row column)
2670
- metadata = metadata.assign(prc = lambda x: x['plate'] + '_' + x['row'] + '_' +x['col'])
2564
+ metadata = metadata.assign(prc = lambda x: x['plate'] + '_' + x['row_name'] + '_' +x['column_name'])
2671
2565
 
2672
2566
  #Count cells per well
2673
2567
  cells_well = pd.DataFrame(metadata.groupby('prc')['object_label'].nunique())
@@ -2679,7 +2573,7 @@ def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=False, pathog
2679
2573
  metadata.drop(columns=object_label_cols, inplace=True)
2680
2574
 
2681
2575
  #Add prcfo column (plate row column field object)
2682
- metadata = metadata.assign(prcfo = lambda x: x['plate'] + '_' + x['row'] + '_' +x['col']+ '_' +x['field']+ '_' +x['object_label'])
2576
+ metadata = metadata.assign(prcfo = lambda x: x['plate'] + '_' + x['row_name'] + '_' +x['column_name']+ '_' +x['field']+ '_' +x['object_label'])
2683
2577
  metadata.set_index('prcfo', inplace=True)
2684
2578
 
2685
2579
  merged_df = metadata.merge(merged_df, left_index=True, right_index=True)
@@ -3030,8 +2924,7 @@ def generate_loaders(src, mode='train', image_size=224, batch_size=32, classes=[
3030
2924
  def generate_training_dataset(settings):
3031
2925
 
3032
2926
  # Function to filter png_list_df by prcfo present in df without merging
3033
- def filter_png_list(db_path, settings):
3034
- tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']
2927
+ def filter_png_list(db_path, settings, tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']):
3035
2928
  df, _ = _read_and_merge_data(locs=[db_path],
3036
2929
  tables=tables,
3037
2930
  verbose=False,
@@ -3053,9 +2946,8 @@ def generate_training_dataset(settings):
3053
2946
  return size
3054
2947
 
3055
2948
  # Measurement-based selection logic
3056
- def measurement_based_selection(settings, db_path):
2949
+ def measurement_based_selection(settings, db_path, tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']):
3057
2950
  class_paths_ls = []
3058
- tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']
3059
2951
  df, _ = _read_and_merge_data(locs=[db_path],
3060
2952
  tables=tables,
3061
2953
  verbose=False,
@@ -3068,7 +2960,7 @@ def generate_training_dataset(settings):
3068
2960
  treatment_loc=settings['class_metadata'])#, types=settings['metadata_type_by'])
3069
2961
  print('length df 2', len(df))
3070
2962
 
3071
- png_list_df = filter_png_list(db_path, settings)
2963
+ png_list_df = filter_png_list(db_path, settings, tables=settings['tables'])
3072
2964
 
3073
2965
  if settings['custom_measurement']:
3074
2966
  if isinstance(settings['custom_measurement'], list):
@@ -3101,8 +2993,8 @@ def generate_training_dataset(settings):
3101
2993
  # Metadata-based selection logic
3102
2994
  def metadata_based_selection(db_path, settings):
3103
2995
  class_paths_ls = []
3104
- df = filter_png_list(db_path, settings)
3105
-
2996
+ df = filter_png_list(db_path, settings, tables=settings['tables'])
2997
+
3106
2998
  df['metadata_based_class'] = pd.NA
3107
2999
  for i, class_ in enumerate(settings['classes']):
3108
3000
  ls = settings['class_metadata'][i]
@@ -3126,10 +3018,10 @@ def generate_training_dataset(settings):
3126
3018
  def annotation_based_selection(db_path, dst, settings):
3127
3019
  class_paths_ls = training_dataset_from_annotation(db_path, dst, settings['annotation_column'], annotated_classes=settings['annotated_classes'])
3128
3020
 
3129
- size = get_smallest_class_size(class_paths_ls, settings, 'annotation')
3130
- for i, class_paths in enumerate(class_paths_ls):
3131
- if len(class_paths) > size:
3132
- class_paths_ls[i] = random.sample(class_paths, size)
3021
+ #size = get_smallest_class_size(class_paths_ls, settings, 'annotation')
3022
+ #for i, class_paths in enumerate(class_paths_ls):
3023
+ # if len(class_paths) > size:
3024
+ # class_paths_ls[i] = random.sample(class_paths, size)
3133
3025
 
3134
3026
  return class_paths_ls
3135
3027
 
@@ -3137,6 +3029,13 @@ def generate_training_dataset(settings):
3137
3029
  from .utils import get_paths_from_db, annotate_conditions, save_settings
3138
3030
  from .settings import set_generate_training_dataset_defaults
3139
3031
 
3032
+ if 'nucleus' not in settings['tables']:
3033
+ settings['nuclei_limit'] = False
3034
+
3035
+ if 'pathogen' not in settings['tables']:
3036
+ settings['pathogen_limit'] = 0
3037
+ settings['uninfected'] = True
3038
+
3140
3039
  # Set default settings and save
3141
3040
  settings = set_generate_training_dataset_defaults(settings)
3142
3041
  save_settings(settings, 'cv_dataset', show=True)
@@ -3145,6 +3044,7 @@ def generate_training_dataset(settings):
3145
3044
 
3146
3045
  if isinstance(settings['src'], str):
3147
3046
  src = [settings['src']]
3047
+ settings['src'] = src
3148
3048
 
3149
3049
  for i, src in enumerate(settings['src']):
3150
3050
  db_path = os.path.join(src, 'measurements', 'measurements.db')
@@ -3170,7 +3070,7 @@ def generate_training_dataset(settings):
3170
3070
  class_paths_ls = metadata_based_selection(db_path, settings)
3171
3071
 
3172
3072
  elif settings['dataset_mode'] == 'measurement':
3173
- class_paths_ls = measurement_based_selection(settings, db_path)
3073
+ class_paths_ls = measurement_based_selection(settings, db_path, tables=settings['tables'])
3174
3074
 
3175
3075
  if class_path_list is None:
3176
3076
  class_path_list = [[] for _ in range(len(class_paths_ls))]
@@ -3180,22 +3080,72 @@ def generate_training_dataset(settings):
3180
3080
  class_path_list[idx].extend(class_paths_ls[idx])
3181
3081
 
3182
3082
  # Generate and return training and testing directories
3083
+ print('class_path_list',len(class_path_list))
3183
3084
  train_class_dir, test_class_dir = generate_dataset_from_lists(dst, class_data=class_path_list, classes=settings['classes'], test_split=settings['test_split'])
3184
3085
 
3185
3086
  return train_class_dir, test_class_dir
3186
3087
 
3187
3088
  def training_dataset_from_annotation(db_path, dst, annotation_column='test', annotated_classes=(1, 2)):
3188
3089
  all_paths = []
3189
-
3090
+
3190
3091
  # Connect to the database and retrieve the image paths and annotations
3191
3092
  print(f'Reading DataBase: {db_path}')
3192
3093
  with sqlite3.connect(db_path) as conn:
3193
3094
  cursor = conn.cursor()
3194
- # Prepare the query with parameterized placeholders for annotated_classes
3195
- placeholders = ','.join('?' * len(annotated_classes))
3196
- query = f"SELECT png_path, {annotation_column} FROM png_list WHERE {annotation_column} IN ({placeholders})"
3197
- cursor.execute(query, annotated_classes)
3095
+ # Retrieve all paths and annotations from the database
3096
+ query = f"SELECT png_path, {annotation_column} FROM png_list"
3097
+ cursor.execute(query)
3098
+
3099
+ while True:
3100
+ rows = cursor.fetchmany(1000)
3101
+ if not rows:
3102
+ break
3103
+ for row in rows:
3104
+ all_paths.append(row)
3105
+
3106
+ print('Total paths retrieved:', len(all_paths))
3107
+
3108
+ # Filter paths based on annotated_classes
3109
+ class_paths = []
3110
+ for class_ in annotated_classes:
3111
+ class_paths_temp = [path for path, annotation in all_paths if annotation == class_]
3112
+ class_paths.append(class_paths_temp)
3113
+ print(f'Found {len(class_paths_temp)} images in class {class_}')
3114
+
3115
+ # If only one class is provided, create an alternative list by sampling paths from all_paths that are not in the annotated class
3116
+ if len(annotated_classes) == 1:
3117
+ target_class = annotated_classes[0]
3118
+ count_target_class = len(class_paths[0])
3119
+ print(f'Annotated class: {target_class} with {count_target_class} images')
3120
+
3121
+ # Filter all_paths to exclude paths that belong to the target class
3122
+ alt_class_paths = [path for path, annotation in all_paths if annotation != target_class]
3123
+ print('Alternative paths available:', len(alt_class_paths))
3124
+
3125
+ # Randomly sample an equal number of images for the second class
3126
+ sampled_alt_class_paths = random.sample(alt_class_paths, min(count_target_class, len(alt_class_paths)))
3127
+ print(f'Sampled {len(sampled_alt_class_paths)} alternative images for balancing')
3128
+
3129
+ # Append this list as the second class
3130
+ class_paths.append(sampled_alt_class_paths)
3131
+
3132
+ print(f'Generated a list of lists from annotation of {len(class_paths)} classes')
3133
+ for i, ls in enumerate(class_paths):
3134
+ print(f'Class {i}: {len(ls)} images')
3135
+
3136
+ return class_paths
3198
3137
 
3138
+ def training_dataset_from_annotation_v2(db_path, dst, annotation_column='test', annotated_classes=(1, 2)):
3139
+ all_paths = []
3140
+
3141
+ # Connect to the database and retrieve the image paths and annotations
3142
+ print(f'Reading DataBase: {db_path}')
3143
+ with sqlite3.connect(db_path) as conn:
3144
+ cursor = conn.cursor()
3145
+ # Retrieve all paths and annotations from the database
3146
+ query = f"SELECT png_path, {annotation_column} FROM png_list"
3147
+ cursor.execute(query)
3148
+
3199
3149
  while True:
3200
3150
  rows = cursor.fetchmany(1000)
3201
3151
  if not rows:
@@ -3203,13 +3153,36 @@ def training_dataset_from_annotation(db_path, dst, annotation_column='test', ann
3203
3153
  for row in rows:
3204
3154
  all_paths.append(row)
3205
3155
 
3206
- # Filter paths based on annotation
3156
+ print('Total paths retrieved:', len(all_paths))
3157
+
3158
+ # Filter paths based on annotated_classes
3207
3159
  class_paths = []
3208
3160
  for class_ in annotated_classes:
3209
3161
  class_paths_temp = [path for path, annotation in all_paths if annotation == class_]
3210
3162
  class_paths.append(class_paths_temp)
3163
+ print(f'Found {len(class_paths_temp)} images in class {class_}')
3164
+
3165
+ # If only one class is provided, create an alternative list by sampling paths from all_paths that are not in the annotated class
3166
+ if len(annotated_classes) == 1:
3167
+ target_class = annotated_classes[0]
3168
+ count_target_class = len(class_paths[0])
3169
+ print(f'Annotated class: {target_class} with {count_target_class} images')
3170
+
3171
+ # Filter all_paths to exclude paths that belong to the target class
3172
+ alt_class_paths = [path for path, annotation in all_paths if annotation != target_class]
3173
+ print('Alternative paths available:', len(alt_class_paths))
3174
+
3175
+ # Randomly sample an equal number of images for the second class
3176
+ sampled_alt_class_paths = random.sample(alt_class_paths, min(count_target_class, len(alt_class_paths)))
3177
+ print(f'Sampled {len(sampled_alt_class_paths)} alternative images for balancing')
3178
+
3179
+ # Append this list as the second class
3180
+ class_paths.append(sampled_alt_class_paths)
3211
3181
 
3212
3182
  print(f'Generated a list of lists from annotation of {len(class_paths)} classes')
3183
+ for i, ls in enumerate(class_paths):
3184
+ print(f'Class {i}: {len(ls)} images')
3185
+
3213
3186
  return class_paths
3214
3187
 
3215
3188
  def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1):
@@ -3228,8 +3201,9 @@ def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1):
3228
3201
  test_class_dir = os.path.join(dst, f'test/{cls}')
3229
3202
  os.makedirs(train_class_dir, exist_ok=True)
3230
3203
  os.makedirs(test_class_dir, exist_ok=True)
3231
-
3204
+
3232
3205
  # Split the data
3206
+ print('data',len(data), test_split)
3233
3207
  train_data, test_data = train_test_split(data, test_size=test_split, shuffle=True, random_state=42)
3234
3208
 
3235
3209
  # Copy train files
spacr/measure.py CHANGED
@@ -16,6 +16,7 @@ from skimage.util import img_as_bool
16
16
  import matplotlib.pyplot as plt
17
17
  from math import ceil, sqrt
18
18
 
19
+
19
20
  def get_components(cell_mask, nucleus_mask, pathogen_mask):
20
21
  """
21
22
  Get the components (nucleus and pathogens) for each cell in the given masks.
@@ -761,12 +762,10 @@ def _measure_crop_core(index, time_ls, file, settings):
761
762
  if settings['cytoplasm_min_size'] is not None and settings['cytoplasm_min_size'] != 0:
762
763
  cytoplasm_mask = _filter_object(cytoplasm_mask, settings['cytoplasm_min_size'])
763
764
 
764
- if settings['cell_mask_dim'] is not None:
765
+ if settings['cell_mask_dim'] is not None and settings['nucleus_mask_dim'] is not None and settings['pathogen_mask_dim'] is not None:
765
766
  cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask = _exclude_objects(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, uninfected=settings['uninfected'])
766
-
767
- # Update data with the new masks
768
- if settings['cell_mask_dim'] is not None:
769
767
  data[:, :, settings['cell_mask_dim']] = cell_mask.astype(data_type)
768
+
770
769
  if settings['nucleus_mask_dim'] is not None:
771
770
  data[:, :, settings['nucleus_mask_dim']] = nucleus_mask.astype(data_type)
772
771
  if settings['pathogen_mask_dim'] is not None:
@@ -779,7 +778,6 @@ def _measure_crop_core(index, time_ls, file, settings):
779
778
  figs[f'{file_name}__after_filtration'] = fig
780
779
 
781
780
  if settings['save_measurements']:
782
-
783
781
  cell_df, nucleus_df, pathogen_df, cytoplasm_df = _morphological_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, settings)
784
782
 
785
783
  #if settings['skeleton']:
@@ -789,7 +787,6 @@ def _measure_crop_core(index, time_ls, file, settings):
789
787
  cell_intensity_df, nucleus_intensity_df, pathogen_intensity_df, cytoplasm_intensity_df = _intensity_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, channel_arrays, settings, sizes=[1, 2, 3, 4, 5], periphery=True, outside=True)
790
788
  if settings['cell_mask_dim'] is not None:
791
789
  cell_merged_df = _merge_and_save_to_database(cell_df, cell_intensity_df, 'cell', source_folder, file_name, settings['experiment'], settings['timelapse'])
792
-
793
790
  if settings['nucleus_mask_dim'] is not None:
794
791
  nucleus_merged_df = _merge_and_save_to_database(nucleus_df, nucleus_intensity_df, 'nucleus', source_folder, file_name, settings['experiment'], settings['timelapse'])
795
792
 
@@ -800,7 +797,6 @@ def _measure_crop_core(index, time_ls, file, settings):
800
797
  cytoplasm_merged_df = _merge_and_save_to_database(cytoplasm_df, cytoplasm_intensity_df, 'cytoplasm', source_folder, file_name, settings['experiment'], settings['timelapse'])
801
798
 
802
799
  if settings['save_png'] or settings['save_arrays'] or settings['plot']:
803
-
804
800
  if isinstance(settings['dialate_pngs'], bool):
805
801
  dialate_pngs = [settings['dialate_pngs'], settings['dialate_pngs'], settings['dialate_pngs']]
806
802
  if isinstance(settings['dialate_pngs'], list):
@@ -825,13 +821,15 @@ def _measure_crop_core(index, time_ls, file, settings):
825
821
 
826
822
  if len(crop_ls) != len(size_ls):
827
823
  print(f"Setting: size_ls: {settings['png_size']} should be a list of integers, or a list of lists of integers if crop_ls: {settings['crop_mode']} has multiple elements")
828
-
824
+
829
825
  for crop_idx, crop_mode in enumerate(crop_ls):
830
826
  width, height = size_ls[crop_idx]
827
+
831
828
  if crop_mode == 'cell':
832
829
  crop_mask = cell_mask.copy()
833
830
  dialate_png = dialate_pngs[crop_idx]
834
831
  dialate_png_ratio = dialate_png_ratios[crop_idx]
832
+
835
833
  elif crop_mode == 'nucleus':
836
834
  crop_mask = nucleus_mask.copy()
837
835
  dialate_png = dialate_pngs[crop_idx]
@@ -852,7 +850,7 @@ def _measure_crop_core(index, time_ls, file, settings):
852
850
 
853
851
  for _id in objects_in_image:
854
852
 
855
- region = (crop_mask == _id) # This creates a boolean mask for the region of interest
853
+ region = (crop_mask == _id)
856
854
 
857
855
  # Use the boolean mask to filter the cell_mask and then find unique IDs
858
856
  region_cell_ids = np.atleast_1d(np.unique(cell_mask[region]))
@@ -947,7 +945,7 @@ def measure_crop(settings):
947
945
 
948
946
  from .io import _save_settings_to_db
949
947
  from .timelapse import _timelapse_masks_to_gif
950
- from .utils import measure_test_mode, print_progress
948
+ from .utils import measure_test_mode, print_progress, save_settings
951
949
  from .settings import get_measure_crop_settings
952
950
 
953
951
  if not isinstance(settings['src'], (str, list)):
@@ -1032,9 +1030,10 @@ def measure_crop(settings):
1032
1030
  settings['crop_mode'] = [settings['crop_mode']]
1033
1031
  settings['crop_mode'] = [str(crop_mode) for crop_mode in settings['crop_mode']]
1034
1032
  print(f"Converted crop_mode to list: {settings['crop_mode']}")
1035
- return
1036
1033
 
1037
1034
  _save_settings_to_db(settings)
1035
+ #save_settings(settings, name='measure_crop', show=True)
1036
+
1038
1037
  files = [f for f in os.listdir(settings['src']) if f.endswith('.npy')]
1039
1038
  n_jobs = settings['n_jobs']
1040
1039
  print(f'using {n_jobs} cpu cores')