spacr 0.4.15__py3-none-any.whl → 0.4.60__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/io.py CHANGED
@@ -23,6 +23,7 @@ import seaborn as sns
 from nd2reader import ND2Reader
 from torchvision import transforms
 from sklearn.model_selection import train_test_split
+import readlif
 
 def process_non_tif_non_2D_images(folder):
     """Processes all images in the folder and splits them into grayscale channels, preserving bit depth."""
@@ -131,58 +132,61 @@ def process_non_tif_non_2D_images(folder):
 
 def _load_images_and_labels(image_files, label_files, invert=False):
 
-    from .utils import invert_image, apply_mask
+    from .utils import invert_image
 
     images = []
     labels = []
-
-    if not image_files is None:
-        image_names = sorted([os.path.basename(f) for f in image_files])
-    else:
-        image_names = []
-
-    if not label_files is None:
-        label_names = sorted([os.path.basename(f) for f in label_files])
-    else:
-        label_names = []
 
-    if not image_files is None and not label_files is None:
+    image_names = sorted([os.path.basename(f) for f in image_files]) if image_files else []
+    label_names = sorted([os.path.basename(f) for f in label_files]) if label_files else []
+
+    if image_files and label_files:
         for img_file, lbl_file in zip(image_files, label_files):
             image = cellpose.io.imread(img_file)
+            if image is None:
+                print(f"WARNING: Could not load image: {img_file}")
+                continue
             if invert:
                 image = invert_image(image)
-            label = cellpose.io.imread(lbl_file)
             if image.max() > 1:
                 image = image / image.max()
+
+            label = cellpose.io.imread(lbl_file)
+            if label is None:
+                print(f"WARNING: Could not load label: {lbl_file}")
+                continue
+
             images.append(image)
             labels.append(label)
-    elif not image_files is None:
+
+    elif image_files:
         for img_file in image_files:
            image = cellpose.io.imread(img_file)
+            if image is None:
+                print(f"WARNING: Could not load image: {img_file}")
+                continue
            if invert:
                image = invert_image(image)
            if image.max() > 1:
                image = image / image.max()
            images.append(image)
-    elif not image_files is None:
-        for lbl_file in label_files:
-            label = cellpose.io.imread(lbl_file)
+
+    elif label_files:
+        for lbl_file in label_files:
+            label = cellpose.io.imread(lbl_file)
+            if label is None:
+                print(f"WARNING: Could not load label: {lbl_file}")
+                continue
            labels.append(label)
-
-    if not image_files is None:
-        image_dir = os.path.dirname(image_files[0])
-    else:
-        image_dir = None
-
-    if not label_files is None:
-        label_dir = os.path.dirname(label_files[0])
-    else:
-        label_dir = None
-
-    # Log the number of loaded images and labels
+
+    image_dir = os.path.dirname(image_files[0]) if image_files else None
+    label_dir = os.path.dirname(label_files[0]) if label_files else None
+
     print(f'Loaded {len(images)} images and {len(labels)} labels from {image_dir} and {label_dir}')
-    if len(labels) > 0 and len(images) > 0:
-        print(f'image shape: {images[0].shape}, image type: images[0].shape mask shape: {labels[0].shape}, image type: labels[0].shape')
+    if images and labels:
+        print(f'image shape: {images[0].shape}, image type: {images[0].dtype}; '
+              f'label shape: {labels[0].shape}, label type: {labels[0].dtype}')
+
     return images, labels, image_names, label_names
 
 def _load_normalized_images_and_labels(image_files, label_files, channels=None, percentiles=None,
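For reference, a minimal usage sketch of the refactored loader; the glob patterns are illustrative, not part of the package, and unreadable files are now skipped with a WARNING rather than raising:

    # Hypothetical paths; _load_images_and_labels is a private helper in spacr.io.
    from glob import glob
    from spacr.io import _load_images_and_labels

    image_files = sorted(glob('images/*.tif'))
    label_files = sorted(glob('masks/*.tif'))
    images, labels, image_names, label_names = _load_images_and_labels(image_files, label_files, invert=False)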
@@ -647,7 +651,7 @@ def load_images_from_paths(images_by_key):
 
     return images_dict
 
-#@log_function_call
+#@log_function_call
 def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=False, skip_mode='01', metadata_type='', img_format='.tif'):
     """
     Convert z-stack images to maximum intensity projection (MIP) images.
@@ -664,13 +668,16 @@ def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=Fals
         None
     """
 
+    if isinstance(img_format, str):
+        img_format = [img_format]
+
     from .utils import _extract_filename_metadata, print_progress
 
     regular_expression = re.compile(regex)
     stack_path = os.path.join(src, 'stack')
     files_processed = 0
     if not os.path.exists(stack_path) or (os.path.isdir(stack_path) and len(os.listdir(stack_path)) == 0):
-        all_filenames = [filename for filename in os.listdir(src) if filename.endswith(img_format)]
+        all_filenames = [filename for filename in os.listdir(src) if any(filename.endswith(ext) for ext in img_format)]
         print(f'All files: {len(all_filenames)} in {src}')
         time_ls = []
         image_paths_by_key = _extract_filename_metadata(all_filenames, src, regular_expression, metadata_type, pick_slice, skip_mode)
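With the str-to-list normalization above, a single extension and a list of extensions are both accepted; a sketch of the two call styles (src and regex are placeholders):

    # Equivalent in spirit after the isinstance() normalization; src and regex are placeholders.
    _rename_and_organize_image_files(src, regex, img_format='.tif')
    _rename_and_organize_image_files(src, regex, img_format=['.tif', '.nd2', '.czi'])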
@@ -729,11 +736,11 @@ def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=Fals
             images_by_key.clear()
 
         # Move original images to a new directory
-        valid_exts = [img_format]
         newpath = os.path.join(src, 'orig')
         os.makedirs(newpath, exist_ok=True)
         for filename in os.listdir(src):
-            if os.path.splitext(filename)[1] in valid_exts:
+            #print(f"{filename}: {os.path.splitext(filename)[1]}")
+            if os.path.splitext(filename)[1] in img_format:
                 move = os.path.join(newpath, filename)
                 if os.path.exists(move):
                     print(f'WARNING: A file with the same name already exists at location {move}')
@@ -1571,7 +1578,7 @@ def preprocess_img_data(settings):
     Returns:
         None
     """
-
+
     src = settings['src']
     valid_ext = ['tif', 'tiff', 'png', 'jpeg']
     files = os.listdir(src)
@@ -1599,11 +1606,11 @@ def preprocess_img_data(settings):
 
     mask_channels = [settings['nucleus_channel'], settings['cell_channel'], settings['pathogen_channel']]
     backgrounds = [settings['nucleus_background'], settings['cell_background'], settings['pathogen_background']]
-
+
     settings, metadata_type, custom_regex, nr, plot, batch_size, timelapse, lower_percentile, randomize, all_to_mip, pick_slice, skip_mode, cmap, figuresize, normalize, save_dtype, test_mode, test_images, random_test = set_default_settings_preprocess_img_data(settings)
-
+
     regex = _get_regex(metadata_type, img_format, custom_regex)
-
+
     if test_mode:
 
         print(f'Running spacr in test mode')
@@ -1612,6 +1619,8 @@ def preprocess_img_data(settings):
             os.rmdir(os.path.join(src, 'test'))
             print(f"Deleted test directory: {os.path.join(src, 'test')}")
         except OSError as e:
+            print(f"Error deleting test directory: {e}")
+            print(f"Delete manually before running test mode")
             pass
 
         src = _run_test_mode(settings['src'], regex, timelapse, test_images, random_test)
@@ -1628,6 +1637,7 @@ def preprocess_img_data(settings):
     if timelapse:
         _move_to_chan_folder(src, regex, timelapse, metadata_type)
     else:
+        img_format = ['.tif', '.tiff', '.png', '.jpg', '.jpeg', '.bmp', '.nd2', '.czi', '.lif']
         _rename_and_organize_image_files(src, regex, batch_size, pick_slice, skip_mode, metadata_type, img_format)
 
     #Make sure no batches will be of only one image
@@ -1650,7 +1660,7 @@ def preprocess_img_data(settings):
     if len(settings['channels']) != nr_channel_folders:
         print(f"Number of channels does not match number of channel folders. channels: {settings['channels']} channel folders: {nr_channel_folders}")
         new_channels = list(range(nr_channel_folders))
-        print(f"Setting channels to {new_channels}")
+        print(f"Changing channels from {settings['channels']} to {new_channels}")
         settings['channels'] = new_channels
 
     if timelapse:
@@ -1780,11 +1790,11 @@ def _read_and_join_tables(db_path, table_names=['cell', 'cytoplasm', 'nucleus',
             print(e)
     conn.close()
     if 'png_list' in dataframes:
-        png_list_df = dataframes['png_list'][['cell_id', 'png_path', 'plate', 'row_name', 'column_name', 'field']].copy()
+        png_list_df = dataframes['png_list'][['cell_id', 'png_path', 'plateID', 'rowID', 'columnID', 'fieldID']].copy()
         png_list_df['cell_id'] = png_list_df['cell_id'].str[1:].astype(int)
         png_list_df.rename(columns={'cell_id': 'object_label'}, inplace=True)
         if 'cell' in dataframes:
-            join_cols = ['object_label', 'plate', 'row_name', 'column_name','field']
+            join_cols = ['object_label', 'plateID', 'rowID', 'columnID','fieldID']
             dataframes['cell'] = pd.merge(dataframes['cell'], png_list_df, on=join_cols, how='left')
         else:
             print("Cell table not found in database tables.")
@@ -2085,14 +2095,18 @@ def _read_db(db_loc, tables):
     Returns:
     - dfs (list): A list of pandas DataFrames, each containing the data from a table.
     """
-    from .utils import rename_columns_in_db
+    from .utils import rename_columns_in_db, correct_metadata
+
     rename_columns_in_db(db_loc)
     conn = sqlite3.connect(db_loc)
     dfs = []
+
     for table in tables:
         query = f'SELECT * FROM {table}'
         df = pd.read_sql_query(query, conn)
+        df = correct_metadata(df)
         dfs.append(df)
+
     conn.close()
     return dfs
 
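A hedged sketch of how the updated reader might be called; the database path and table list are illustrative, and correct_metadata presumably standardizes the metadata columns (plateID, rowID, columnID, fieldID) before each DataFrame is returned:

    # Hypothetical call; 'measurements.db' and the table names are examples only.
    dfs = _read_db('measurements.db', tables=['cell', 'nucleus', 'cytoplasm', 'png_list'])
    cell_df = dfs[0]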
@@ -2271,7 +2285,7 @@ def _copy_missclassified(df):
 
 def _read_db(db_loc, tables):
 
-    from .utils import rename_columns_in_db
+    from .utils import rename_columns_in_db, correct_metadata
 
     rename_columns_in_db(db_loc)
     conn = sqlite3.connect(db_loc) # Create a connection to the database
@@ -2279,12 +2293,13 @@ def _read_db(db_loc, tables):
     for table in tables:
         query = f'SELECT * FROM {table}' # Write a SQL query to get the data from the database
         df = pd.read_sql_query(query, conn) # Use the read_sql_query function to get the data and save it as a DataFrame
+        df = correct_metadata(df)
         dfs.append(df)
     conn.close() # Close the connection
     return dfs
 
 def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=10, pathogen_limit=10, change_plate=False):
-    from .io import _read_db
+
     from .utils import _split_data
 
     # Initialize an empty dictionary to store DataFrames by table name
@@ -2294,8 +2309,8 @@ def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=10, pathogen_
     for idx, loc in enumerate(locs):
         db_dfs = _read_db(loc, tables)
         if change_plate:
-            db_dfs['plate'] = f'plate{idx+1}'
-            db_dfs['prc'] = db_dfs['plate'].astype(str) + '_' + db_dfs['row_name'].astype(str) + '_' + db_dfs['column_name'].astype(str)
+            db_dfs['plateID'] = f'plate{idx+1}'
+            db_dfs['prc'] = db_dfs['plateID'].astype(str) + '_' + db_dfs['rowID'].astype(str) + '_' + db_dfs['columnID'].astype(str)
         for table, df in zip(tables, db_dfs):
             data_dict[table].append(df)
 
@@ -2303,6 +2318,7 @@ def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=10, pathogen_
     for table, dfs in data_dict.items():
         if dfs:
             data_dict[table] = pd.concat(dfs, axis=0)
+
             if verbose:
                 print(f"{table}: {len(data_dict[table])}")
 
@@ -2389,18 +2405,18 @@ def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=10, pathogen_
     if 'png_list' in data_dict:
         png_list = data_dict['png_list'].copy()
         png_list_g_df_numeric, png_list_g_df_non_numeric = _split_data(png_list, 'prcfo', 'cell_id')
-        png_list_g_df_non_numeric.drop(columns=['plate','row_name','column_name','field','file_name','cell_id', 'prcf'], inplace=True)
+        png_list_g_df_non_numeric.drop(columns=['plateID','rowID','columnID','fieldID','file_name','cell_id', 'prcf'], inplace=True)
         if verbose:
             print(f'png_list: {len(png_list)}, png_list grouped: {len(png_list_g_df_numeric)}')
             print(f"Added png_list columns: {png_list_g_df_numeric.columns}, {png_list_g_df_non_numeric.columns}")
         merged_df = merged_df.merge(png_list_g_df_numeric, left_index=True, right_index=True)
         merged_df = merged_df.merge(png_list_g_df_non_numeric, left_index=True, right_index=True)
-
+
     # Add prc (plate row column) and prcfo (plate row column field object) columns
-    metadata = metadata.assign(prc=lambda x: x['plate'] + '_' + x['row_name'] + '_' + x['column_name'])
+    metadata = metadata.assign(prc=lambda x: x['plateID'] + '_' + x['rowID'] + '_' + x['columnID'])
     cells_well = metadata.groupby('prc')['object_label'].nunique().reset_index(name='cells_per_well')
     metadata = metadata.merge(cells_well, on='prc')
-    metadata = metadata.assign(prcfo=lambda x: x['plate'] + '_' + x['row_name'] + '_' + x['column_name'] + '_' + x['field'] + '_' + x['object_label'])
+    metadata = metadata.assign(prcfo=lambda x: x['plateID'] + '_' + x['rowID'] + '_' + x['columnID'] + '_' + x['fieldID'] + '_' + x['object_label'])
     metadata.set_index('prcfo', inplace=True)
 
     # Merge metadata with final merged DataFrame
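The prc and prcfo keys are plain underscore-joined identifiers built from the renamed metadata columns; a toy pandas illustration (values are invented):

    import pandas as pd

    # Toy frame with the renamed metadata columns; values are made up.
    metadata = pd.DataFrame({'plateID': ['plate1'], 'rowID': ['r1'], 'columnID': ['c1'],
                             'fieldID': ['f1'], 'object_label': ['o1']})
    metadata = metadata.assign(prc=lambda x: x['plateID'] + '_' + x['rowID'] + '_' + x['columnID'])
    metadata = metadata.assign(prcfo=lambda x: x['plateID'] + '_' + x['rowID'] + '_' + x['columnID'] + '_' + x['fieldID'] + '_' + x['object_label'])
    # prc -> 'plate1_r1_c1'; prcfo -> 'plate1_r1_c1_f1_o1'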
@@ -2988,7 +3004,7 @@ def training_dataset_from_annotation(db_path, dst, annotation_column='test', ann
 
     return class_paths
 
-def training_dataset_from_annotation_metadata(db_path, dst, annotation_column='test', annotated_classes=(1, 2), metadata_type_by='column_name', class_metadata=['c1','c2']):
+def training_dataset_from_annotation_metadata(db_path, dst, annotation_column='test', annotated_classes=(1, 2), metadata_type_by='columnID', class_metadata=['c1','c2']):
     all_paths = []
 
     # Connect to the database and retrieve the image paths and annotations
@@ -3010,9 +3026,9 @@ def training_dataset_from_annotation_metadata(db_path, dst, annotation_column='t
 
     # Filter all_paths by metadata_type_by and class_metadata
     filtered_paths = []
-    metadata_index = {'row_name': 2, 'column_name': 3}.get(metadata_type_by, None)
+    metadata_index = {'rowID': 2, 'columnID': 3}.get(metadata_type_by, None)
     if metadata_index is None:
-        raise ValueError(f"Invalid metadata_type_by value: {metadata_type_by}. Must be 'row_name' or 'column_name'. {class_metadata} must be a list formatted as ['c1', 'c2'] or ['r1', 'r2']")
+        raise ValueError(f"Invalid metadata_type_by value: {metadata_type_by}. Must be 'rowID' or 'columnID'. {class_metadata} must be a list formatted as ['c1', 'c2'] or ['r1', 'r2']")
 
     for row in all_paths:
         if row[metadata_index] in class_metadata:
@@ -3102,4 +3118,473 @@ def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1)
         test_class_dir = os.path.join(dst, f'test/{cls}')
         print(f'Train class {cls}: {len(os.listdir(train_class_dir))}, Test class {cls}: {len(os.listdir(test_class_dir))}')
 
-    return os.path.join(dst, 'train'), os.path.join(dst, 'test')
+    return os.path.join(dst, 'train'), os.path.join(dst, 'test')
+
+def convert_separate_files_to_yokogawa(folder, regex):
+
+    ROWS = "ABCDEFGHIJKLMNOP"
+    COLS = [f"{i:02d}" for i in range(1, 25)]
+    WELLS = [f"{r}{c}" for r in ROWS for c in COLS]
+
+    def _get_next_well(used_wells):
+        plate = 1
+        for well in WELLS:
+            well_name = f"plate{plate}_{well}"
+            if well_name not in used_wells:
+                return well_name
+            if well == "P24":
+                plate += 1
+        return f"plate{plate}_A01"
+
+    pattern = re.compile(regex, re.I)
+
+    files_by_region = {}
+    rename_log = []
+    csv_path = os.path.join(folder, "rename_log.csv")
+    used_wells = set()
+    region_to_well = {}
+
+    # Group files by (plateID, wellID, fieldID, timeID, chanID)
+    for file in os.listdir(folder):
+        match = pattern.match(file)
+        if not match:
+            print(f"Skipping {file}: does not match regex.")
+            continue
+
+        meta = match.groupdict()
+
+        # Mandatory metadata
+        if 'wellID' not in meta or meta['wellID'] is None:
+            print(f"Skipping {file}: missing mandatory wellID.")
+            continue
+        wellID = meta['wellID']
+
+        # Optional metadata with defaults
+        plateID = meta.get('plateID', '1') or '1'
+        fieldID = meta.get('fieldID', '1') or '1'
+        timeID = int(meta.get('timeID', 1) or 1)
+        chanID = int(meta.get('chanID', 1) or 1)
+        sliceID = meta.get('sliceID')
+        sliceID = int(sliceID) if sliceID is not None else None
+
+        region_key = (plateID, wellID, fieldID, timeID, chanID)
+
+        files_by_region.setdefault(region_key, []).append((file, sliceID))
+
+    # Assign wells and process files per region
+    for region, file_list in files_by_region.items():
+        if region[:3] not in region_to_well:
+            next_well = _get_next_well(used_wells)
+            region_to_well[region[:3]] = next_well
+            used_wells.add(next_well)
+
+        assigned_well = region_to_well[region[:3]]
+        plateID, wellID, fieldID, timeID, chanID = region
+
+        # Check if multiple slices exist and are meaningful
+        slice_ids = [sid for _, sid in file_list if sid is not None]
+        unique_slices = set(slice_ids)
+
+        images = []
+        for filename, _ in sorted(file_list, key=lambda x: x[1] or 1):
+            img = tifffile.imread(os.path.join(folder, filename))
+            images.append(img)
+
+        # Perform MIP only if multiple unique slices are present
+        if len(unique_slices) > 1:
+            img_to_save = np.max(np.stack(images), axis=0)
+        else:
+            img_to_save = images[0]
+
+        dtype = img_to_save.dtype
+
+        new_filename = f"{assigned_well}_T{timeID:04d}F{int(fieldID):03d}L01C{chanID:02d}.tif"
+        new_filepath = os.path.join(folder, new_filename)
+        tifffile.imwrite(new_filepath, img_to_save.astype(dtype))
+
+        # Log original filenames involved in MIP or single file rename
+        original_files = ";".join(f[0] for f in file_list)
+        rename_log.append({"Original File(s)": original_files, "Renamed TIFF": new_filename})
+
+    pd.DataFrame(rename_log).to_csv(csv_path, index=False)
+    print(f"Processing complete. Files saved in {folder} and rename log saved as {csv_path}.")
+
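convert_separate_files_to_yokogawa relies on named regex groups: wellID is mandatory, while plateID, fieldID, timeID, chanID and sliceID are optional. A hedged example call; the folder path and filename pattern below are made up:

    # Matches hypothetical names such as 'exp1_A02_f01_t001_z03_c2.tif'.
    regex = r'(?P<plateID>\w+)_(?P<wellID>[A-P]\d{2})_f(?P<fieldID>\d+)_t(?P<timeID>\d+)_z(?P<sliceID>\d+)_c(?P<chanID>\d+)\.tif'
    convert_separate_files_to_yokogawa('/path/to/folder', regex)
    # Writes Yokogawa-style TIFFs plus rename_log.csv into the same folder.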
+def convert_to_yokogawa(folder):
+    """
+    Detects file type in the folder and converts them
+    to Yokogawa-style naming with Maximum Intensity Projection (MIP).
+    """
+
+    #def _get_next_well(used_wells):
+    #    """
+    #    Determines the next available well position in a 384-well format.
+    #    Iterates wells, and after P24, switches to plate2.
+    #    """
+    #    plate = 1
+    #    for well in WELLS:
+    #        well_name = f"plate{plate}_{well}"
+    #        if well_name not in used_wells:
+    #            used_wells.add(well_name)
+    #            return well_name
+    #        if well == "P24":
+    #            plate += 1
+    #    raise ValueError("All wells exhausted.")
+
+    def _get_next_well(used_wells):
+        """
+        Determines the next available well position across multiple 384-well plates.
+        """
+        ROWS = "ABCDEFGHIJKLMNOP"
+        COLS = [f"{i:02d}" for i in range(1, 25)]
+        WELLS = [f"{r}{c}" for r in ROWS for c in COLS]
+
+        plate = 1
+        while True:
+            for well in WELLS:
+                well_name = f"plate{plate}_{well}"
+                if well_name not in used_wells:
+                    used_wells.add(well_name)
+                    return well_name
+            plate += 1  # All wells exhausted in current plate, increment to next plate
+
+
+    # Define 384-well plate format
+    ROWS = "ABCDEFGHIJKLMNOP"
+    COLS = [f"{i:02d}" for i in range(1, 25)]
+    WELLS = [f"{r}{c}" for r in ROWS for c in COLS]
+
+    filenames = []
+    rename_log = []
+    csv_path = os.path.join(folder, "rename_log.csv")
+    used_wells = set()
+
+    # **Dictionary to store well assignments per original file**
+    file_to_well = {}
+
+    for file in os.listdir(folder):
+        path = os.path.join(folder, file)
+        ext = file.lower().split('.')[-1]
+
+        # **Assign a well only once per original file**
+        if file not in file_to_well:
+            file_to_well[file] = _get_next_well(used_wells)
+            #used_wells.add(file_to_well[file]) # Mark it as used
+
+        well = file_to_well[file] # Use the same well for all channels/times
+
+        ### **Process Nikon ND2 Files**
+        if ext == 'nd2':
+            try:
+                nd2 = ND2Reader(path)
+                metadata = nd2.metadata
+
+                timepoints = list(range(len(metadata.get("frames", [0])))) or [0]
+                fields = list(range(len(metadata.get("fields_of_view", [0])))) or [0]
+                z_levels = list(metadata.get("z_levels", range(1))) if metadata.get("z_levels") else [0]
+                channels = metadata.get("channels", [])
+
+                for t_idx in timepoints:
+                    for f_idx in fields:
+                        for c_idx, channel in enumerate(channels):
+                            try:
+                                mip_image = np.max.reduce([
+                                    nd2.get_frame_2D(t=t_idx, v=f_idx, z=z_idx, c=c_idx)
+                                    for z_idx in z_levels
+                                ], axis=0)
+
+                                dtype = mip_image.dtype
+                                filename = f"{well}_T{t_idx+1:04d}F{f_idx+1:03d}L01C{c_idx+1:02d}.tif"
+                                filepath = os.path.join(folder, filename)
+
+                                tifffile.imwrite(filepath, mip_image.astype(dtype))
+                                rename_log.append({"Original File": file, "Renamed TIFF": filename})
+
+                            except IndexError:
+                                print(f"Warning: ND2 file {file} has an incomplete data structure. Skipping.")
+
+            except Exception as e:
+                print(f"Error processing ND2 file {file}: {e}")
+
+        ### **Process Zeiss CZI Files**
+        elif ext == 'czi':
+            with czifile.CziFile(path) as czi:
+                img_data = czi.asarray() # Read the full image array
+
+                # Remove singleton dimensions (if any)
+                img_data = np.squeeze(img_data)
+
+                # Get the actual shape of the data
+                shape = img_data.shape
+                num_dims = len(shape)
+
+                # Default values if dimensions are missing
+                timepoints = 1
+                z_levels = 1
+                channels = 1
+
+                # Determine dimension mapping dynamically
+                if num_dims == 2: # (Y, X) → Single 2D image
+                    y_dim, x_dim = shape
+                    img_data = img_data.reshape(1, 1, 1, y_dim, x_dim) # Add missing dimensions
+                elif num_dims == 3: # (C, Y, X) or (Z, Y, X)
+                    if shape[0] <= 4: # Likely (C, Y, X)
+                        channels, y_dim, x_dim = shape
+                        img_data = img_data.reshape(1, 1, channels, y_dim, x_dim) # Add missing dimensions
+                    else: # Likely (Z, Y, X)
+                        z_levels, y_dim, x_dim = shape
+                        img_data = img_data.reshape(1, z_levels, 1, y_dim, x_dim) # Add missing dimensions
+                elif num_dims == 4: # Could be (T, C, Y, X) or (T, Z, Y, X) or (Z, C, Y, X)
+                    if shape[1] <= 4: # Assume (T, C, Y, X)
+                        timepoints, channels, y_dim, x_dim = shape
+                        img_data = img_data.reshape(timepoints, 1, channels, y_dim, x_dim) # Add missing Z
+                    else: # Assume (T, Z, Y, X) or (Z, C, Y, X)
+                        timepoints, z_levels, y_dim, x_dim = shape
+                        img_data = img_data.reshape(timepoints, z_levels, 1, y_dim, x_dim) # Add missing C
+                elif num_dims == 5: # Standard (T, Z, C, Y, X)
+                    timepoints, z_levels, channels, y_dim, x_dim = shape
+                else:
+                    raise ValueError(f"Unexpected CZI shape: {shape}. Unable to process.")
+
+                # Iterate over detected timepoints, channels, and perform MIP over Z
+                for t_idx in range(timepoints):
+                    for c_idx in range(channels):
+                        # Extract Z-stack or single image
+                        if z_levels > 1:
+                            z_stack = img_data[t_idx, :, c_idx, :, :] # MIP over Z
+                            mip_image = np.max(z_stack, axis=0)
+                        else:
+                            mip_image = img_data[t_idx, 0, c_idx, :, :] # No Z, take directly
+
+                        # Ensure correct dtype
+                        dtype = mip_image.dtype
+
+                        # Generate Yokogawa-style filename
+                        filename = f"{well}_T{t_idx+1:04d}F001L01C{c_idx+1:02d}.tif"
+                        filepath = os.path.join(folder, filename)
+
+                        # Save the extracted image
+                        tifffile.imwrite(filepath, mip_image.astype(dtype))
+
+                        rename_log.append({"Original File": file, "Renamed TIFF": filename})
+
+        ### **Process Leica LIF Files**
+        elif ext == 'lif':
+            try:
+                lif_file = readlif.Reader(path)
+
+                for image_idx, image in enumerate(lif_file.getIterImage()):
+                    timepoints = range(getattr(image.dims, 't', 1))
+                    z_levels = range(getattr(image.dims, 'z', 1))
+                    channels = range(getattr(image.dims, 'c', 1))
+
+                    for t_idx in timepoints:
+                        for c_idx in channels:
+                            z_stack = []
+                            for z_idx in z_levels:
+                                try:
+                                    frame = image.getFrame(z=z_idx, t=t_idx, c=c_idx)
+                                    z_stack.append(frame)
+                                except IndexError:
+                                    print(f"Missing frame: T{t_idx}, Z{z_idx}, C{c_idx} in {file}, skipping frame.")
+
+                            if z_stack:
+                                mip_image = np.max(np.stack(z_stack), axis=0)
+                                dtype = mip_image.dtype
+                                filename = f"{well}_T{t_idx+1:04d}F{image_idx+1:03d}L01C{c_idx+1:02d}.tif"
+                                filepath = os.path.join(folder, filename)
+
+                                tifffile.imwrite(filepath, mip_image.astype(dtype))
+                                rename_log.append({"Original File": file, "Renamed TIFF": filename})
+
+            except Exception as e:
+                print(f"Error processing LIF file {file}: {e}")
+
+        ### **Process Standard Image Files (TIFF, PNG, JPEG, BMP)**
+        elif ext in ['tif', 'tiff', 'png', 'jpg', 'jpeg', 'bmp'] and not file.startswith("plate"):
+            try:
+                with tifffile.TiffFile(path) as tif:
+                    images = tif.asarray()
+                    ndim = images.ndim
+
+                    # Defaults
+                    t_dim = z_dim = c_dim = 1
+
+                    # Determine dimensions more explicitly
+                    if ndim == 2:
+                        mip_image = images
+                        filename = f"{well}_T0001F001L01C01.tif"
+                        tifffile.imwrite(os.path.join(folder, filename), mip_image)
+                        rename_log.append({"Original File": file, "Renamed TIFF": filename})
+                        continue
+
+                    elif ndim == 3:
+                        if images.shape[0] <= 4: # Likely channels
+                            c_dim = images.shape[0]
+                            for c in range(c_dim):
+                                mip_image = images[c, :, :]
+                                filename = f"{well}_T0001F001L01C{c+1:02d}.tif"
+                                tifffile.imwrite(os.path.join(folder, filename), mip_image)
+                                rename_log.append({"Original File": file, "Renamed TIFF": filename})
+                        else: # Z-stack
+                            mip_image = np.max(images, axis=0)
+                            filename = f"{well}_T0001F001L01C01.tif"
+                            tifffile.imwrite(os.path.join(folder, filename), mip_image)
+                            rename_log.append({"Original File": file, "Renamed TIFF": filename})
+
+                    elif ndim == 4:
+                        t_dim, z_dim, y_dim, x_dim = images.shape
+                        for t in range(t_dim):
+                            mip_image = np.max(images[t, :, :, :], axis=0)
+                            filename = f"{well}_T{t+1:04d}F001L01C01.tif"
+                            tifffile.imwrite(os.path.join(folder, filename), mip_image)
+                            rename_log.append({"Original File": file, "Renamed TIFF": filename})
+
+                    else:
+                        raise ValueError(f"Unsupported TIFF dimensions: {images.shape}")
+
+            except Exception as e:
+                print(f"Error processing standard image file {file}: {e}")
+
+
+    # Save rename log as CSV
+    pd.DataFrame(rename_log).to_csv(csv_path, index=False)
+    print(f"Processing complete. Files saved in {folder} and rename log saved as {csv_path}.")
+
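A minimal usage sketch for convert_to_yokogawa (the folder path is a placeholder): it scans the folder, assigns each original file a well, and writes MIP TIFFs plus rename_log.csv back into the same folder:

    # Hypothetical folder containing .nd2, .czi, .lif and/or standard image files.
    convert_to_yokogawa('/path/to/acquisition_folder')
    # Produces files like 'plate1_A01_T0001F001L01C01.tif' and rename_log.csv.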
+def apply_augmentation(image, method):
+    if method == 'rotate90':
+        return cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
+    elif method == 'rotate180':
+        return cv2.rotate(image, cv2.ROTATE_180)
+    elif method == 'rotate270':
+        return cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
+    elif method == 'flip_h':
+        return cv2.flip(image, 1)
+    elif method == 'flip_v':
+        return cv2.flip(image, 0)
+    return image
+
+def process_instruction(entry):
+    img = tifffile.imread(entry["src_img"])
+    msk = tifffile.imread(entry["src_msk"])
+    if entry["augment"]:
+        img = apply_augmentation(img, entry["augment"])
+        msk = apply_augmentation(msk, entry["augment"])
+    tifffile.imwrite(entry["dst_img"], img)
+    tifffile.imwrite(entry["dst_msk"], msk)
+    return 1
+
+def prepare_cellpose_dataset(input_root, augment_data=False, train_fraction=0.8, n_jobs=None):
+
+    from .utils import print_progress
+
+    time_ls = []
+    input_root = os.path.abspath(input_root)
+    output_root = os.path.join(input_root, "cellpose_dataset")
+
+    def get_augmentations():
+        return ['rotate90', 'rotate180', 'rotate270', 'flip_h', 'flip_v']
+
+    def find_image_mask_pairs(dataset_path):
+        mask_dir = os.path.join(dataset_path, "masks")
+        pairs = []
+        for fname in os.listdir(dataset_path):
+            if fname.lower().endswith((".tif", ".tiff")):
+                img_path = os.path.join(dataset_path, fname)
+                msk_path = os.path.join(mask_dir, fname)
+                if os.path.isfile(msk_path):
+                    pairs.append((img_path, msk_path))
+        return pairs
+
+    def prepare_output_folders(base):
+        for subset in ["train", "test"]:
+            os.makedirs(os.path.join(base, subset, "images"), exist_ok=True)
+            os.makedirs(os.path.join(base, subset, "masks"), exist_ok=True)
+
+    print("Scanning datasets...")
+    datasets = []
+    for subdir in os.listdir(input_root):
+        dataset_path = os.path.join(input_root, subdir)
+        if os.path.isdir(dataset_path) and os.path.isdir(os.path.join(dataset_path, "masks")):
+            pairs = find_image_mask_pairs(dataset_path)
+            if pairs:
+                datasets.append(pairs)
+                print(f"  Found {len(pairs)} images in {dataset_path}")
+
+    if not datasets:
+        raise ValueError("No valid datasets with images and masks found.")
+
+    prepare_output_folders(output_root)
+
+    min_size = min(len(pairs) for pairs in datasets)
+    target_size = min_size if not augment_data else max(len(pairs) for pairs in datasets)
+
+    print("\nPreparing instruction list...")
+    instructions = []
+    global_index = 0
+
+    for pairs in datasets:
+        dataset_len = len(pairs)
+
+        # --- Step 1: Sample or augment ---
+        sampled_pairs = []
+        if dataset_len >= target_size:
+            sampled_pairs = random.sample(pairs, target_size)
+        else:
+            sampled_pairs = pairs.copy()
+            if augment_data:
+                needed = target_size - dataset_len
+                aug_methods = get_augmentations()
+                full_loops = needed // len(aug_methods)
+                extra = needed % len(aug_methods)
+
+                for _ in range(full_loops):
+                    for (img_path, msk_path), aug in zip(pairs, aug_methods * (dataset_len // len(aug_methods))):
+                        sampled_pairs.append((img_path, msk_path, aug))
+                if extra > 0:
+                    subset = random.sample(pairs * ((extra // len(aug_methods)) + 1), extra)
+                    for (img_path, msk_path), aug in zip(subset, aug_methods[:extra]):
+                        sampled_pairs.append((img_path, msk_path, aug))
+
+        # Add "no augmentation" tag to original files
+        augmented_sampled = [
+            (tup[0], tup[1], None) if len(tup) == 2 else tup
+            for tup in sampled_pairs
+        ]
+
+        # --- Step 2: Split into train/test ---
+        random.shuffle(augmented_sampled)
+        split_idx = int(train_fraction * len(augmented_sampled))
+        split_sets = {
+            "train": augmented_sampled[:split_idx],
+            "test": augmented_sampled[split_idx:]
+        }
+
+        for subset, items in split_sets.items():
+            for img_path, msk_path, aug in items:
+                dst_img = os.path.join(output_root, subset, "images", f"{global_index:05d}.tif")
+                dst_msk = os.path.join(output_root, subset, "masks", f"{global_index:05d}.tif")
+                instructions.append({
+                    "src_img": img_path,
+                    "src_msk": msk_path,
+                    "dst_img": dst_img,
+                    "dst_msk": dst_msk,
+                    "augment": aug
+                })
+                global_index += 1
+
+    print(f"Total files to process: {len(instructions)}")
+
+    # --- Step 3: Process with multiprocessing ---
+    print("Processing images with multiprocessing...")
+
+    if n_jobs is None:
+        n_jobs = max(1, cpu_count() - 1)
+    else:
+        n_jobs = int(n_jobs)
+
+    with Pool(n_jobs) as pool:
+        for i, _ in enumerate(pool.imap_unordered(process_instruction, instructions), 1):
+            print_progress(i, len(instructions), n_jobs=n_jobs, time_ls=time_ls, batch_size=None, operation_type="cellpose dataset")
+
+    print(f"Done. Dataset saved to: {output_root}")
+
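A hedged usage sketch for prepare_cellpose_dataset; the layout below is inferred from find_image_mask_pairs (each dataset subfolder holds TIFF images plus a masks/ subfolder with identically named masks), and the paths are illustrative:

    # /data/annotations/datasetA/0001.tif           image
    # /data/annotations/datasetA/masks/0001.tif     matching mask (same file name)
    # /data/annotations/datasetB/...
    prepare_cellpose_dataset('/data/annotations', augment_data=True, train_fraction=0.8, n_jobs=4)
    # Writes cellpose_dataset/{train,test}/{images,masks}/ under the input root.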