spacr 0.4.15__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. spacr/__init__.py +2 -2
  2. spacr/core.py +52 -10
  3. spacr/deep_spacr.py +2 -3
  4. spacr/gui.py +0 -1
  5. spacr/gui_core.py +247 -41
  6. spacr/gui_elements.py +133 -2
  7. spacr/gui_utils.py +22 -17
  8. spacr/io.py +624 -149
  9. spacr/ml.py +141 -258
  10. spacr/plot.py +76 -34
  11. spacr/resources/MEDIAR/__pycache__/SetupDict.cpython-39.pyc +0 -0
  12. spacr/resources/MEDIAR/__pycache__/evaluate.cpython-39.pyc +0 -0
  13. spacr/resources/MEDIAR/__pycache__/generate_mapping.cpython-39.pyc +0 -0
  14. spacr/resources/MEDIAR/__pycache__/main.cpython-39.pyc +0 -0
  15. spacr/resources/MEDIAR/core/Baseline/__pycache__/Predictor.cpython-39.pyc +0 -0
  16. spacr/resources/MEDIAR/core/Baseline/__pycache__/Trainer.cpython-39.pyc +0 -0
  17. spacr/resources/MEDIAR/core/Baseline/__pycache__/__init__.cpython-39.pyc +0 -0
  18. spacr/resources/MEDIAR/core/Baseline/__pycache__/utils.cpython-39.pyc +0 -0
  19. spacr/resources/MEDIAR/core/MEDIAR/__pycache__/EnsemblePredictor.cpython-39.pyc +0 -0
  20. spacr/resources/MEDIAR/core/MEDIAR/__pycache__/Predictor.cpython-39.pyc +0 -0
  21. spacr/resources/MEDIAR/core/MEDIAR/__pycache__/Trainer.cpython-39.pyc +0 -0
  22. spacr/resources/MEDIAR/core/MEDIAR/__pycache__/__init__.cpython-39.pyc +0 -0
  23. spacr/resources/MEDIAR/core/MEDIAR/__pycache__/utils.cpython-39.pyc +0 -0
  24. spacr/resources/MEDIAR/core/__pycache__/BasePredictor.cpython-39.pyc +0 -0
  25. spacr/resources/MEDIAR/core/__pycache__/BaseTrainer.cpython-39.pyc +0 -0
  26. spacr/resources/MEDIAR/core/__pycache__/__init__.cpython-39.pyc +0 -0
  27. spacr/resources/MEDIAR/core/__pycache__/utils.cpython-39.pyc +0 -0
  28. spacr/resources/MEDIAR/train_tools/__pycache__/__init__.cpython-39.pyc +0 -0
  29. spacr/resources/MEDIAR/train_tools/__pycache__/measures.cpython-39.pyc +0 -0
  30. spacr/resources/MEDIAR/train_tools/__pycache__/utils.cpython-39.pyc +0 -0
  31. spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/__init__.cpython-39.pyc +0 -0
  32. spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/datasetter.cpython-39.pyc +0 -0
  33. spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/transforms.cpython-39.pyc +0 -0
  34. spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/utils.cpython-39.pyc +0 -0
  35. spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/CellAware.cpython-39.pyc +0 -0
  36. spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/LoadImage.cpython-39.pyc +0 -0
  37. spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/NormalizeImage.cpython-39.pyc +0 -0
  38. spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/__init__.cpython-39.pyc +0 -0
  39. spacr/resources/MEDIAR/train_tools/models/__pycache__/MEDIARFormer.cpython-39.pyc +0 -0
  40. spacr/resources/MEDIAR/train_tools/models/__pycache__/__init__.cpython-39.pyc +0 -0
  41. spacr/sequencing.py +73 -38
  42. spacr/settings.py +161 -135
  43. spacr/submodules.py +618 -215
  44. spacr/timelapse.py +197 -29
  45. spacr/toxo.py +23 -23
  46. spacr/utils.py +186 -128
  47. {spacr-0.4.15.dist-info → spacr-0.5.0.dist-info}/METADATA +5 -2
  48. {spacr-0.4.15.dist-info → spacr-0.5.0.dist-info}/RECORD +53 -24
  49. spacr/stats.py +0 -221
  50. /spacr/{cellpose.py → spacr_cellpose.py} +0 -0
  51. {spacr-0.4.15.dist-info → spacr-0.5.0.dist-info}/LICENSE +0 -0
  52. {spacr-0.4.15.dist-info → spacr-0.5.0.dist-info}/WHEEL +0 -0
  53. {spacr-0.4.15.dist-info → spacr-0.5.0.dist-info}/entry_points.txt +0 -0
  54. {spacr-0.4.15.dist-info → spacr-0.5.0.dist-info}/top_level.txt +0 -0
spacr/timelapse.py CHANGED
@@ -7,9 +7,9 @@ from IPython.display import display
 from IPython.display import Image as ipyimage
 import trackpy as tp
 from btrack import datasets as btrack_datasets
-from skimage.measure import regionprops
+from skimage.measure import regionprops, regionprops_table
 from scipy.signal import find_peaks
-from scipy.optimize import curve_fit
+from scipy.optimize import curve_fit, linear_sum_assignment
 from scipy.integrate import trapz
 import matplotlib.pyplot as plt
 
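The two new imports back the rewritten feature extraction and the IoU matcher added below: regionprops_table turns a labeled frame into columnar per-object measurements, and linear_sum_assignment solves the frame-to-frame matching problem. A minimal standalone sketch of the regionprops_table call (note that skimage expects the property name 'centroid', which expands into 'centroid-0'/'centroid-1' columns):

    import numpy as np
    from skimage.measure import regionprops_table

    frame = np.array([[1, 1, 0],
                      [0, 0, 2]])  # toy labeled mask: objects 1 and 2
    props = regionprops_table(frame, properties=('label', 'centroid', 'area'))
    # props is a dict of flat arrays: 'label', 'centroid-0', 'centroid-1', 'area'
    print(props['label'], props['area'])  # one entry per labeled object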
@@ -255,7 +255,7 @@ def _relabel_masks_based_on_tracks(masks, tracks, mode='btrack'):
 
     return relabeled_masks
 
-def _prepare_for_tracking(mask_array):
+def _prepare_for_tracking_v1(mask_array):
     """
     Prepare the mask array for object tracking.
 
@@ -286,6 +286,87 @@ def _prepare_for_tracking(mask_array):
         })
     return pd.DataFrame(frames)
 
+def _prepare_for_tracking(mask_array):
+    frames = []
+    for t, frame in enumerate(mask_array):
+        props = regionprops_table(
+            frame,
+            properties=('label', 'centroid-0', 'centroid-1', 'area',
+                        'bbox-0', 'bbox-1', 'bbox-2', 'bbox-3',
+                        'eccentricity')
+        )
+        df = pd.DataFrame(props)
+        df = df.rename(columns={
+            'centroid-0': 'y', 'centroid-1': 'x', 'area': 'mass',
+            'label': 'original_label'
+        })
+        df['frame'] = t
+        frames.append(df[['frame','y','x','mass','original_label',
+                          'bbox-0','bbox-1','bbox-2','bbox-3','eccentricity']])
+    return pd.concat(frames, ignore_index=True)
+
+
+def _track_by_iou(masks, iou_threshold=0.1):
+    """
+    Build a track table by linking masks frame→frame via IoU.
+    Returns a DataFrame with columns [frame, original_label, track_id].
+    """
+    n_frames = masks.shape[0]
+    # 1) initialize: every label in frame 0 starts its own track
+    labels0 = np.unique(masks[0])[1:]
+    next_track = 1
+    track_map = {}  # (frame,label) -> track_id
+    for L in labels0:
+        track_map[(0, L)] = next_track
+        next_track += 1
+
+    # 2) iterate through frames
+    for t in range(1, n_frames):
+        prev, curr = masks[t-1], masks[t]
+        matches = link_by_iou(prev, curr, iou_threshold=iou_threshold)
+        used_curr = set()
+        # a) assign matched labels to existing tracks
+        for L_prev, L_curr in matches:
+            tid = track_map[(t-1, L_prev)]
+            track_map[(t, L_curr)] = tid
+            used_curr.add(L_curr)
+        # b) any label in curr not matched → new track
+        for L in np.unique(curr)[1:]:
+            if L not in used_curr:
+                track_map[(t, L)] = next_track
+                next_track += 1
+
+    # 3) flatten into DataFrame
+    records = []
+    for (frame, label), tid in track_map.items():
+        records.append({'frame': frame, 'original_label': label, 'track_id': tid})
+    return pd.DataFrame(records)
+
+def link_by_iou(mask_prev, mask_next, iou_threshold=0.1):
+    # Get labels
+    labels_prev = np.unique(mask_prev)[1:]
+    labels_next = np.unique(mask_next)[1:]
+    # Precompute masks as boolean
+    bool_prev = {L: mask_prev==L for L in labels_prev}
+    bool_next = {L: mask_next==L for L in labels_next}
+    # Cost matrix = 1 - IoU
+    cost = np.ones((len(labels_prev), len(labels_next)), dtype=float)
+    for i, L1 in enumerate(labels_prev):
+        m1 = bool_prev[L1]
+        for j, L2 in enumerate(labels_next):
+            m2 = bool_next[L2]
+            inter = np.logical_and(m1, m2).sum()
+            union = np.logical_or(m1, m2).sum()
+            if union > 0:
+                cost[i, j] = 1 - inter/union
+    # Solve assignment
+    row_ind, col_ind = linear_sum_assignment(cost)
+    matches = []
+    for i, j in zip(row_ind, col_ind):
+        if cost[i,j] <= 1 - iou_threshold:
+            matches.append((labels_prev[i], labels_next[j]))
+    return matches
+
 
 def _find_optimal_search_range(features, initial_search_range=500, increment=10, max_attempts=49, memory=3):
     """
     Find the optimal search range for linking features.
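The linking step above is easier to see in isolation. A self-contained sketch of the same idea: build a 1 − IoU cost matrix between the labels of two consecutive frames, solve it with the Hungarian algorithm (scipy's linear_sum_assignment), and keep pairs whose IoU clears the threshold:

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    prev = np.array([[1, 1, 0],
                     [0, 0, 2]])  # frame t-1: labels 1 and 2
    curr = np.array([[3, 3, 0],
                     [0, 0, 4]])  # frame t: labels 3 and 4

    labels_prev = np.unique(prev)[1:]  # drop background 0
    labels_next = np.unique(curr)[1:]
    cost = np.ones((len(labels_prev), len(labels_next)))
    for i, a in enumerate(labels_prev):
        for j, b in enumerate(labels_next):
            inter = np.logical_and(prev == a, curr == b).sum()
            union = np.logical_or(prev == a, curr == b).sum()
            if union:
                cost[i, j] = 1 - inter / union  # cost = 1 - IoU

    rows, cols = linear_sum_assignment(cost)
    matches = [(int(labels_prev[i]), int(labels_next[j]))
               for i, j in zip(rows, cols)
               if cost[i, j] <= 1 - 0.1]  # keep pairs with IoU >= 0.1
    print(matches)  # [(1, 3), (2, 4)]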
@@ -336,7 +417,94 @@ def _remove_objects_from_first_frame(masks, percentage=10):
             masks[0][first_frame == label] = 0
     return masks
 
-def _facilitate_trackin_with_adaptive_removal(masks, search_range=500, max_attempts=100, memory=3):
+def _track_by_iou(masks, iou_threshold=0.1):
+    """
+    Build a track table by linking masks frame→frame via IoU.
+    Returns a DataFrame with columns [frame, original_label, track_id].
+    """
+    n_frames = masks.shape[0]
+    # 1) initialize: every label in frame 0 starts its own track
+    labels0 = np.unique(masks[0])[1:]
+    next_track = 1
+    track_map = {}  # (frame,label) -> track_id
+    for L in labels0:
+        track_map[(0, L)] = next_track
+        next_track += 1
+
+    # 2) iterate through frames
+    for t in range(1, n_frames):
+        prev, curr = masks[t-1], masks[t]
+        matches = link_by_iou(prev, curr, iou_threshold=iou_threshold)
+        used_curr = set()
+        # a) assign matched labels to existing tracks
+        for L_prev, L_curr in matches:
+            tid = track_map[(t-1, L_prev)]
+            track_map[(t, L_curr)] = tid
+            used_curr.add(L_curr)
+        # b) any label in curr not matched → new track
+        for L in np.unique(curr)[1:]:
+            if L not in used_curr:
+                track_map[(t, L)] = next_track
+                next_track += 1
+
+    # 3) flatten into DataFrame
+    records = []
+    for (frame, label), tid in track_map.items():
+        records.append({'frame': frame, 'original_label': label, 'track_id': tid})
+    return pd.DataFrame(records)
+
+
+def _facilitate_trackin_with_adaptive_removal(masks, search_range=None, max_attempts=5, memory=3, min_mass=50, track_by_iou=False):
+    """
+    Facilitates object tracking with deterministic initial filtering and
+    trackpy's constant-velocity prediction.
+
+    Args:
+        masks (np.ndarray): integer-labeled masks (frames × H × W).
+        search_range (int|None): max displacement; if None, auto-computed.
+        max_attempts (int): how many times to retry with smaller search_range.
+        memory (int): trackpy memory parameter.
+        min_mass (float): drop any object in frame 0 with area < min_mass.
+
+    Returns:
+        masks, features_df, tracks_df
+
+    Raises:
+        RuntimeError if linking fails after max_attempts.
+    """
+    # 1) initial features & filter frame 0 by area
+    features = _prepare_for_tracking(masks)
+    f0 = features[features['frame'] == 0]
+    valid = f0.loc[f0['mass'] >= min_mass, 'original_label'].unique()
+    masks[0] = np.where(np.isin(masks[0], valid), masks[0], 0)
+
+    # 2) recompute features on filtered masks
+    features = _prepare_for_tracking(masks)
+
+    # 3) default search_range = 2×sqrt(99th-pct area)
+    if search_range is None:
+        a99 = f0['mass'].quantile(0.99)
+        search_range = max(1, int(2 * np.sqrt(a99)))
+
+    # 4) attempt linking, shrinking search_range on failure
+    for attempt in range(1, max_attempts + 1):
+        try:
+            if track_by_iou:
+                tracks_df = _track_by_iou(masks, iou_threshold=0.1)
+            else:
+                tracks_df = tp.link_df(features, search_range=search_range, memory=memory, predict=True)
+            print(f"Linked on attempt {attempt} with search_range={search_range}")
+            return masks, features, tracks_df
+
+        except Exception as e:
+            search_range = max(1, int(search_range * 0.8))
+            print(f"Attempt {attempt} failed ({e}); reducing search_range to {search_range}")
+
+    raise RuntimeError(
+        f"Failed to track after {max_attempts} attempts; last search_range={search_range}"
+    )
+
+def _facilitate_trackin_with_adaptive_removal_v1(masks, search_range=500, max_attempts=100, memory=3):
     """
     Facilitates object tracking with adaptive removal.
 
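The retry strategy in the new _facilitate_trackin_with_adaptive_removal is: attempt linking, and on any failure shrink search_range by 20% and try again, up to max_attempts. A minimal self-contained sketch of that loop using trackpy's tp.link on illustrative data (the function above calls tp.link_df):

    import pandas as pd
    import trackpy as tp

    # Two objects that each move about 1 px between two frames (toy data).
    features = pd.DataFrame({
        'frame': [0, 0, 1, 1],
        'x': [10.0, 40.0, 11.0, 41.0],
        'y': [10.0, 40.0, 10.5, 39.5],
    })

    search_range = 5
    for attempt in range(1, 6):
        try:
            tracks = tp.link(features, search_range=search_range, memory=3)
            break  # linked successfully
        except Exception:
            search_range = max(1, int(search_range * 0.8))  # shrink by 20%, retry
    else:
        raise RuntimeError("linking failed after 5 attempts")
    print(tracks['particle'].tolist())  # e.g. [0, 1, 0, 1]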
@@ -533,14 +701,14 @@ def exponential_decay(x, a, b, c):
 
 def preprocess_pathogen_data(pathogen_df):
     # Group by identifiers and count the number of parasites
-    parasite_counts = pathogen_df.groupby(['plate', 'row_name', 'column_name', 'field', 'timeid', 'pathogen_cell_id']).size().reset_index(name='parasite_count')
+    parasite_counts = pathogen_df.groupby(['plateID', 'rowID', 'column_name', 'fieldID', 'timeid', 'pathogen_cell_id']).size().reset_index(name='parasite_count')
 
     # Aggregate numerical columns and take the first of object columns
-    agg_funcs = {col: 'mean' if np.issubdtype(pathogen_df[col].dtype, np.number) else 'first' for col in pathogen_df.columns if col not in ['plate', 'row_name', 'column_name', 'field', 'timeid', 'pathogen_cell_id', 'parasite_count']}
-    pathogen_agg = pathogen_df.groupby(['plate', 'row_name', 'column_name', 'field', 'timeid', 'pathogen_cell_id']).agg(agg_funcs).reset_index()
+    agg_funcs = {col: 'mean' if np.issubdtype(pathogen_df[col].dtype, np.number) else 'first' for col in pathogen_df.columns if col not in ['plateID', 'rowID', 'column_name', 'fieldID', 'timeid', 'pathogen_cell_id', 'parasite_count']}
+    pathogen_agg = pathogen_df.groupby(['plateID', 'rowID', 'column_name', 'fieldID', 'timeid', 'pathogen_cell_id']).agg(agg_funcs).reset_index()
 
     # Merge the counts back into the aggregated data
-    pathogen_agg = pathogen_agg.merge(parasite_counts, on=['plate', 'row_name', 'column_name', 'field', 'timeid', 'pathogen_cell_id'])
+    pathogen_agg = pathogen_agg.merge(parasite_counts, on=['plateID', 'rowID', 'column_name', 'fieldID', 'timeid', 'pathogen_cell_id'])
 
     # Remove the object_label column as it corresponds to the pathogen ID not the cell ID
     if 'object_label' in pathogen_agg.columns:
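The pattern in preprocess_pathogen_data — count rows per identifier group, aggregate the rest, then merge the counts back on — is plain pandas; a small sketch with hypothetical columns:

    import pandas as pd

    df = pd.DataFrame({
        'plateID': ['p1', 'p1', 'p1'],
        'rowID': ['r1', 'r1', 'r2'],
        'pathogen_cell_id': [1, 1, 2],  # two parasites share cell 1
        'intensity': [0.2, 0.4, 0.9],   # hypothetical measurement
    })
    keys = ['plateID', 'rowID', 'pathogen_cell_id']
    counts = df.groupby(keys).size().reset_index(name='parasite_count')
    agg = df.groupby(keys).agg({'intensity': 'mean'}).reset_index()
    out = agg.merge(counts, on=keys)
    print(out)  # one row per cell: mean intensity plus parasite_count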
@@ -604,10 +772,10 @@ def save_results_dataframe(df, src, results_name):
 def summarize_per_well(peak_details_df):
     # Step 1: Split the 'ID' column
     split_columns = peak_details_df['ID'].str.split('_', expand=True)
-    peak_details_df[['plate', 'row_name', 'column', 'field', 'object_number']] = split_columns
+    peak_details_df[['plateID', 'rowID', 'columnID', 'fieldID', 'object_number']] = split_columns
 
-    # Step 2: Create 'well_ID' by combining 'row_name' and 'column'
-    peak_details_df['well_ID'] = peak_details_df['row_name'] + '_' + peak_details_df['column']
+    # Step 2: Create 'well_ID' by combining 'rowID' and 'columnID'
+    peak_details_df['well_ID'] = peak_details_df['rowID'] + '_' + peak_details_df['columnID']
 
     # Filter entries where 'amplitude' is not null
     filtered_df = peak_details_df[peak_details_df['amplitude'].notna()]
@@ -635,10 +803,10 @@ def summarize_per_well(peak_details_df):
 def summarize_per_well_inf_non_inf(peak_details_df):
     # Step 1: Split the 'ID' column
     split_columns = peak_details_df['ID'].str.split('_', expand=True)
-    peak_details_df[['plate', 'row_name', 'column', 'field', 'object_number']] = split_columns
+    peak_details_df[['plateID', 'rowID', 'columnID', 'fieldID', 'object_number']] = split_columns
 
-    # Step 2: Create 'well_ID' by combining 'row_name' and 'column'
-    peak_details_df['well_ID'] = peak_details_df['row_name'] + '_' + peak_details_df['column']
+    # Step 2: Create 'well_ID' by combining 'rowID' and 'columnID'
+    peak_details_df['well_ID'] = peak_details_df['rowID'] + '_' + peak_details_df['columnID']
 
     # Assume 'pathogen_count' indicates infection if > 0
     # Add an 'infected_status' column to classify cells
@@ -669,7 +837,7 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
         pathogen_df = pd.read_sql("SELECT * FROM pathogen", conn)
         pathogen_df['pathogen_cell_id'] = pathogen_df['pathogen_cell_id'].astype(float).astype('Int64')
         pathogen_df = preprocess_pathogen_data(pathogen_df)
-        cell_df = cell_df.merge(pathogen_df, on=['plate', 'row_name', 'column_name', 'field', 'timeid', 'object_label'], how='left', suffixes=('', '_pathogen'))
+        cell_df = cell_df.merge(pathogen_df, on=['plateID', 'rowID', 'column_name', 'fieldID', 'timeid', 'object_label'], how='left', suffixes=('', '_pathogen'))
         cell_df['parasite_count'] = cell_df['parasite_count'].fillna(0)
         print(f'After pathogen merge: {len(cell_df)} objects')
 
@@ -677,7 +845,7 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
     if cytoplasm:
         cytoplasm_df = pd.read_sql(f"SELECT * FROM {'cytoplasm'}", conn)
         # Merge on specified columns
-        cell_df = cell_df.merge(cytoplasm_df, on=['plate', 'row_name', 'column_name', 'field', 'timeid', 'object_label'], how='left', suffixes=('', '_cytoplasm'))
+        cell_df = cell_df.merge(cytoplasm_df, on=['plateID', 'rowID', 'column_name', 'fieldID', 'timeid', 'object_label'], how='left', suffixes=('', '_cytoplasm'))
 
         print(f'After cytoplasm merge: {len(cell_df)} objects')
 
@@ -686,13 +854,13 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
     # Continue with your existing processing on cell_df now containing merged data...
     # Prepare DataFrame (use cell_df instead of df)
     prcf_components = cell_df['prcf'].str.split('_', expand=True)
-    cell_df['plate'] = prcf_components[0]
-    cell_df['row_name'] = prcf_components[1]
-    cell_df['column'] = prcf_components[2]
-    cell_df['field'] = prcf_components[3]
+    cell_df['plateID'] = prcf_components[0]
+    cell_df['rowID'] = prcf_components[1]
+    cell_df['columnID'] = prcf_components[2]
+    cell_df['fieldID'] = prcf_components[3]
     cell_df['time'] = prcf_components[4].str.extract('t(\d+)').astype(int)
     cell_df['object_number'] = cell_df['object_label']
-    cell_df['plate_row_column_field_object'] = cell_df['plate'].astype(str) + '_' + cell_df['row_name'].astype(str) + '_' + cell_df['column'].astype(str) + '_' + cell_df['field'].astype(str) + '_' + cell_df['object_label'].astype(str)
+    cell_df['plate_row_column_field_object'] = cell_df['plateID'].astype(str) + '_' + cell_df['rowID'].astype(str) + '_' + cell_df['columnID'].astype(str) + '_' + cell_df['fieldID'].astype(str) + '_' + cell_df['object_label'].astype(str)
 
     df = cell_df.copy()
 
@@ -752,10 +920,10 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
         if len(peaks) == 0:
             peak_details_list.append({
                 'ID': unique_id,
-                'plate': group['plate'].iloc[0],
-                'row_name': group['row_name'].iloc[0],
-                'column': group['column'].iloc[0],
-                'field': group['field'].iloc[0],
+                'plateID': group['plateID'].iloc[0],
+                'rowID': group['rowID'].iloc[0],
+                'columnID': group['columnID'].iloc[0],
+                'fieldID': group['fieldID'].iloc[0],
                 'object_number': group['object_number'].iloc[0],
                 'time': np.nan, # The time of the peak
                 'amplitude': np.nan,
@@ -783,10 +951,10 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
 
             peak_details_list.append({
                 'ID': unique_id,
-                'plate': group['plate'].iloc[0],
-                'row_name': group['row_name'].iloc[0],
-                'column': group['column'].iloc[0],
-                'field': group['field'].iloc[0],
+                'plateID': group['plateID'].iloc[0],
+                'rowID': group['rowID'].iloc[0],
+                'columnID': group['columnID'].iloc[0],
+                'fieldID': group['fieldID'].iloc[0],
                 'object_number': group['object_number'].iloc[0],
                 'time': peak_time, # The time of the peak
                 'amplitude': amplitude,
spacr/toxo.py CHANGED
@@ -494,25 +494,25 @@ def generate_score_heatmap(settings):
         if 'column_name' in df.columns:
             df = df[df['column_name']==column]
         elif 'column' in df.columns:
-            df['column_name'] = df['column']
+            df['columnID'] = df['column']
             df = df[df['column_name']==column]
         if not plate is None:
-            df['plate'] = f"plate{plate}"
-        grouped_df = df.groupby(['plate', 'row_name', 'column_name'])[data_column].mean().reset_index()
-        grouped_df['prc'] = grouped_df['plate'].astype(str) + '_' + grouped_df['row_name'].astype(str) + '_' + grouped_df['column_name'].astype(str)
+            df['plateID'] = f"plate{plate}"
+        grouped_df = df.groupby(['plateID', 'rowID', 'column_name'])[data_column].mean().reset_index()
+        grouped_df['prc'] = grouped_df['plateID'].astype(str) + '_' + grouped_df['rowID'].astype(str) + '_' + grouped_df['column_name'].astype(str)
         return grouped_df
 
     def calculate_fraction_mixed_condition(csv, plate=1, column='c3', control_sgrnas = ['TGGT1_220950_1', 'TGGT1_233460_4']):
         df = pd.read_csv(csv)
         df = df[df['column_name']==column]
         if plate not in df.columns:
-            df['plate'] = f"plate{plate}"
+            df['plateID'] = f"plate{plate}"
         df = df[df['grna_name'].str.match(f'^{control_sgrnas[0]}$|^{control_sgrnas[1]}$')]
-        grouped_df = df.groupby(['plate', 'row_name', 'column_name'])['count'].sum().reset_index()
+        grouped_df = df.groupby(['plateID', 'rowID', 'column_name'])['count'].sum().reset_index()
         grouped_df = grouped_df.rename(columns={'count': 'total_count'})
-        merged_df = pd.merge(df, grouped_df, on=['plate', 'row_name', 'column_name'])
+        merged_df = pd.merge(df, grouped_df, on=['plateID', 'rowID', 'column_name'])
         merged_df['fraction'] = merged_df['count'] / merged_df['total_count']
-        merged_df['prc'] = merged_df['plate'].astype(str) + '_' + merged_df['row_name'].astype(str) + '_' + merged_df['column_name'].astype(str)
+        merged_df['prc'] = merged_df['plateID'].astype(str) + '_' + merged_df['rowID'].astype(str) + '_' + merged_df['column_name'].astype(str)
         return merged_df
 
     def plot_multi_channel_heatmap(df, column='c3'):
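Throughout this file the 'prc' key is simply plate, row, and column joined with underscores, now built from the renamed ID columns; a one-line sketch:

    import pandas as pd

    df = pd.DataFrame({'plateID': ['plate1'], 'rowID': ['r3'], 'column_name': ['c3']})
    df['prc'] = df['plateID'].astype(str) + '_' + df['rowID'].astype(str) + '_' + df['column_name'].astype(str)
    print(df['prc'].iloc[0])  # plate1_r3_c3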
@@ -524,17 +524,17 @@ def generate_score_heatmap(settings):
        - column: Column to filter by (default is 'c3').
        """
        # Extract row number and convert to integer for sorting
-        df['row_num'] = df['row_name'].str.extract(r'(\d+)').astype(int)
+        df['row_num'] = df['rowID'].str.extract(r'(\d+)').astype(int)
 
        # Filter and sort by plate, row, and column
        df = df[df['column_name'] == column]
-        df = df.sort_values(by=['plate', 'row_num', 'column_name'])
+        df = df.sort_values(by=['plateID', 'row_num', 'column_name'])
 
        # Drop temporary 'row_num' column after sorting
        df = df.drop('row_num', axis=1)
 
        # Create a new column combining plate, row, and column for the index
-        df['plate_row_col'] = df['plate'] + '-' + df['row_name'] + '-' + df['column_name']
+        df['plate_row_col'] = df['plateID'] + '-' + df['rowID'] + '-' + df['column_name']
 
        # Set 'plate_row_col' as the index
        df.set_index('plate_row_col', inplace=True)
@@ -593,9 +593,9 @@ def generate_score_heatmap(settings):
            df = pd.read_csv(csv_file) # Read CSV into DataFrame
            df = df[df['column_name']==column]
            if not plate is None:
-                df['plate'] = f"plate{plate}"
-            # Group the data by 'plate', 'row_name', and 'column_name'
-            grouped_df = df.groupby(['plate', 'row_name', 'column_name'])[data_column].mean().reset_index()
+                df['plateID'] = f"plate{plate}"
+            # Group the data by 'plateID', 'rowID', and 'column_name'
+            grouped_df = df.groupby(['plateID', 'rowID', 'column_name'])[data_column].mean().reset_index()
            # Use the CSV filename to create a new column name
            folder_name = os.path.dirname(csv_file).replace(".csv", "")
            new_column_name = os.path.basename(f"{folder_name}_{data_column}")
@@ -606,8 +606,8 @@ def generate_score_heatmap(settings):
            if combined_df is None:
                combined_df = grouped_df
            else:
-                combined_df = pd.merge(combined_df, grouped_df, on=['plate', 'row_name', 'column_name'], how='outer')
-        combined_df['prc'] = combined_df['plate'].astype(str) + '_' + combined_df['row_name'].astype(str) + '_' + combined_df['column_name'].astype(str)
+                combined_df = pd.merge(combined_df, grouped_df, on=['plateID', 'rowID', 'column_name'], how='outer')
+        combined_df['prc'] = combined_df['plateID'].astype(str) + '_' + combined_df['rowID'].astype(str) + '_' + combined_df['column_name'].astype(str)
        return combined_df
 
    def calculate_mae(df):
@@ -629,16 +629,16 @@ def generate_score_heatmap(settings):
        mae_df = pd.DataFrame(mae_data)
        return mae_df
 
-    result_df = combine_classification_scores(settings['folders'], settings['csv_name'], settings['data_column'], settings['plate'], settings['column'], )
-    df = calculate_fraction_mixed_condition(settings['csv'], settings['plate'], settings['column'], settings['control_sgrnas'])
+    result_df = combine_classification_scores(settings['folders'], settings['csv_name'], settings['data_column'], settings['plateID'], settings['columnID'], )
+    df = calculate_fraction_mixed_condition(settings['csv'], settings['plateID'], settings['columnID'], settings['control_sgrnas'])
    df = df[df['grna_name']==settings['fraction_grna']]
    fraction_df = df[['fraction', 'prc']]
    merged_df = pd.merge(fraction_df, result_df, on=['prc'])
-    cv_df = group_cv_score(settings['cv_csv'], settings['plate'], settings['column'], settings['data_column_cv'])
+    cv_df = group_cv_score(settings['cv_csv'], settings['plateID'], settings['columnID'], settings['data_column_cv'])
    cv_df = cv_df[[settings['data_column_cv'], 'prc']]
    merged_df = pd.merge(merged_df, cv_df, on=['prc'])
 
-    fig = plot_multi_channel_heatmap(merged_df, settings['column'])
+    fig = plot_multi_channel_heatmap(merged_df, settings['columnID'])
    if 'row_number' in merged_df.columns:
        merged_df = merged_df.drop('row_num', axis=1)
    mae_df = calculate_mae(merged_df)
@@ -646,9 +646,9 @@ def generate_score_heatmap(settings):
        mae_df = mae_df.drop('row_num', axis=1)
 
    if not settings['dst'] is None:
-        mae_dst = os.path.join(settings['dst'], f"mae_scores_comparison_plate_{settings['plate']}.csv")
-        merged_dst = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['plate']}_data.csv")
-        heatmap_save = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['plate']}.pdf")
+        mae_dst = os.path.join(settings['dst'], f"mae_scores_comparison_plate_{settings['plateID']}.csv")
+        merged_dst = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['plateID']}_data.csv")
+        heatmap_save = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['plateID']}.pdf")
        mae_df.to_csv(mae_dst, index=False)
        merged_df.to_csv(merged_dst, index=False)
        fig.savefig(heatmap_save, format='pdf', dpi=600, bbox_inches='tight')
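Callers of generate_score_heatmap must switch to the renamed settings keys ('plateID' and 'columnID' instead of 'plate' and 'column'). A hypothetical settings dict, with placeholder paths and column names:

    settings = {
        'folders': ['/data/screen/model_a', '/data/screen/model_b'],  # placeholder paths
        'csv_name': 'scores.csv',                 # hypothetical file name
        'data_column': 'score',                   # hypothetical score column
        'plateID': 1,                             # was 'plate' before 0.5.0
        'columnID': 'c3',                         # was 'column' before 0.5.0
        'csv': '/data/screen/counts.csv',         # placeholder path
        'control_sgrnas': ['TGGT1_220950_1', 'TGGT1_233460_4'],
        'fraction_grna': 'TGGT1_220950_1',
        'cv_csv': '/data/screen/cv.csv',          # placeholder path
        'data_column_cv': 'score_cv',             # hypothetical column name
        'dst': '/data/screen/output',             # placeholder output folder
    }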