spacr 0.3.81__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/ml.py CHANGED
@@ -337,7 +337,140 @@ def minimum_cell_simulation(settings, num_repeats=10, sample_size=100, tolerance
        df = pd.read_csv(score_data)
        df = correct_metadata_column_names(df)
        df['plate'] = f'plate{i + 1}'
-        df['prc'] = df['plate'] + '_' + df['row_name'].astype(str) + '_' + df['column'].astype(str)
+
+        if 'prc' not in df.columns:
+            df['prc'] = df['plate'] + '_' + df['row'].astype(str) + '_' + df['column'].astype(str)
+
+        dfs.append(df)
+
+    df = pd.concat(dfs, axis=0)
+
+    # Compute the number of cells per well and select the top sample_size wells by cell count
+    cell_counts = df.groupby('prc').size().reset_index(name='cell_count')
+    top_wells = cell_counts.nlargest(sample_size, 'cell_count')['prc']
+
+    # Filter the data to include only the selected wells
+    df = df[df['prc'].isin(top_wells)]
+
+    # Initialize storage for the absolute-difference data
+    diff_data = []
+
+    # Group by well and iterate over the groups
+    for i, (prc, group) in enumerate(df.groupby('prc')):
+        original_mean = group[settings['score_column']].mean()  # Original full-well mean
+        max_cells = len(group)
+        sample_sizes = np.arange(2, max_cells + 1, increment)  # Sample sizes from 2 to max cells
+
+        # Iterate over the sample sizes and compute the absolute difference
+        for sample_size in sample_sizes:
+            abs_diffs = []
+
+            # Perform multiple random samples to reduce noise
+            for _ in range(num_repeats):
+                sample = group.sample(n=sample_size, replace=False)
+                sampled_mean = sample[settings['score_column']].mean()
+                abs_diff = abs(sampled_mean - original_mean)  # Absolute difference
+                abs_diffs.append(abs_diff)
+
+            # Compute the average absolute difference across all repeats
+            avg_abs_diff = np.mean(abs_diffs)
+
+            # Store the result for plotting
+            diff_data.append((sample_size, avg_abs_diff))
+
+    # Convert the absolute-difference data to a DataFrame for plotting
+    diff_df = pd.DataFrame(diff_data, columns=['sample_size', 'avg_abs_diff'])
+
+    # Group by sample size to calculate the mean and standard deviation
+    summary_df = diff_df.groupby('sample_size').agg(
+        mean_abs_diff=('avg_abs_diff', 'mean'),
+        std_abs_diff=('avg_abs_diff', 'std')
+    ).reset_index()
+
+    # Apply smoothing using a rolling window
+    summary_df['smoothed_mean_abs_diff'] = summary_df['mean_abs_diff'].rolling(window=smoothing, min_periods=1).mean()
+
+    # Convert a percentage tolerance to a fraction
+    if isinstance(settings['tolerance'], int):
+        tolerance_fraction = settings['tolerance'] / 100  # Convert e.g. 2 (%) to 0.02
+    elif isinstance(settings['tolerance'], float):
+        tolerance_fraction = settings['tolerance']
+    else:
+        raise ValueError("Tolerance must be an integer 0 - 100 or a float 0.0 - 1.0.")
+
+    # Compute the relative threshold for each well
+    relative_thresholds = {
+        prc: tolerance_fraction * group[settings['score_column']].mean()  # Fraction of the original mean
+        for prc, group in df.groupby('prc')
+    }
+
+    # Detect the elbow point, where the mean absolute difference falls below the relative threshold
+    summary_df['relative_threshold'] = summary_df['sample_size'].map(
+        lambda size: np.mean([relative_thresholds[prc] for prc in top_wells])  # Average across the selected wells
+    )
+
+    elbow_df = summary_df[summary_df['smoothed_mean_abs_diff'] <= summary_df['relative_threshold']]
+
+    # Select the first occurrence if it exists; otherwise, use the last point
+    if not elbow_df.empty:
+        elbow_point = elbow_df.iloc[0]  # First point where the condition is met
+    else:
+        elbow_point = summary_df.iloc[-1]  # Fallback to the last point
+
+    # Plot the mean absolute difference with the standard deviation as a shaded area
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.plot(
+        summary_df['sample_size'], summary_df['smoothed_mean_abs_diff'], color='teal', label='Smoothed Mean Absolute Difference'
+    )
+    ax.fill_between(
+        summary_df['sample_size'],
+        summary_df['smoothed_mean_abs_diff'] - summary_df['std_abs_diff'],
+        summary_df['smoothed_mean_abs_diff'] + summary_df['std_abs_diff'],
+        color='teal', alpha=0.3, label='±1 Std. Dev.'
+    )
+
+    # Mark the elbow point (inflection) on the plot
+    ax.axvline(elbow_point['sample_size'], color='black', linestyle='--', label='Elbow Point')
+
+    # Format the plot
+    ax.set_xlabel('Sample Size')
+    ax.set_ylabel('Mean Absolute Difference')
+    ax.set_title('Mean Absolute Difference vs. Sample Size with Standard Deviation')
+    ax.legend().remove()
+
+    # Save the plot to a results folder next to the count data
+    dst = os.path.dirname(settings['count_data'][0])
+    if dst is not None:
+        fig_path = os.path.join(dst, 'results')
+        os.makedirs(fig_path, exist_ok=True)
+        fig_file_path = os.path.join(fig_path, 'cell_min_threshold.pdf')
+        fig.savefig(fig_file_path, format='pdf', dpi=600, bbox_inches='tight')
+        print(f"Saved {fig_file_path}")
+
+    plt.show()
+    return elbow_point['sample_size']
+
+def minimum_cell_simulation_v1(settings, num_repeats=10, sample_size=100, tolerance=0.02, smoothing=10, increment=10):
+    """
+    Plot the mean absolute difference, with the standard deviation as a shaded area, vs. sample size.
+    Detect and mark the elbow point (inflection), with smoothing and tolerance control.
+    """
+
+    from spacr.utils import correct_metadata_column_names
+
+    # Load and process the data
+    if isinstance(settings['score_data'], str):
+        settings['score_data'] = [settings['score_data']]
+
+    dfs = []
+    for i, score_data in enumerate(settings['score_data']):
+        df = pd.read_csv(score_data)
+        df = correct_metadata_column_names(df)
+        df['plate'] = f'plate{i + 1}'
+
+        if 'prc' not in df.columns:
+            df['prc'] = df['plate'] + '_' + df['row'].astype(str) + '_' + df['column'].astype(str)
+
        dfs.append(df)
 
    df = pd.concat(dfs, axis=0)
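Taken together, the new minimum_cell_simulation body estimates the smallest per-well sample size at which the subsampled mean stabilizes: it repeatedly subsamples each well at increasing sizes, tracks the mean absolute difference from the full-well mean, smooths the curve, and picks the first size below a tolerance-scaled threshold. A standalone sketch of that elbow rule on toy data (the helper name and data are illustrative, not part of spacr):

    import numpy as np

    def elbow_sample_size(scores, tolerance=0.02, num_repeats=10, increment=10, seed=0):
        """Smallest subsample size whose mean stays within `tolerance` of the full-well mean."""
        rng = np.random.default_rng(seed)
        full_mean = scores.mean()
        threshold = tolerance * abs(full_mean)  # relative threshold, as in the diff
        for n in range(2, len(scores) + 1, increment):
            # average absolute deviation of the subsample mean over several draws
            diffs = [abs(rng.choice(scores, size=n, replace=False).mean() - full_mean)
                     for _ in range(num_repeats)]
            if np.mean(diffs) <= threshold:
                return n  # first size that meets the tolerance: the elbow
        return len(scores)  # fallback: the whole well

    scores = np.random.default_rng(1).normal(0.5, 0.1, size=300)  # toy per-cell scores
    print(elbow_sample_size(scores))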
@@ -697,8 +830,8 @@ def save_summary_to_file(model, file_path='summary.csv'):
 
 def perform_regression(settings):
 
-    from .plot import plot_plates
-    from .utils import merge_regression_res_with_metadata, save_settings
+    from .plot import plot_plates, plot_data_from_csv
+    from .utils import merge_regression_res_with_metadata, save_settings, calculate_shortest_distance
     from .settings import get_perform_regression_default_settings
     from .toxo import go_term_enrichment_by_column, custom_volcano_plot, plot_gene_phenotypes, plot_gene_heatmaps
     from .sequencing import graph_sequencing_stats
@@ -715,7 +848,9 @@ def perform_regression(settings):
             df = pd.read_csv(count_data)
             df['plate_name'] = f'plate{i+1}'
             if 'column' in df.columns:
-                df['column_name'] = df['column']
+                df.rename(columns={'column': 'column_name'}, inplace=True)
+            if 'col' in df.columns:
+                df.rename(columns={'col': 'column_name'}, inplace=True)
             count_data_df = pd.concat([count_data_df, df])
         print('Count data:', len(count_data_df))
 
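Note the difference between the old and new normalization: `df['column_name'] = df['column']` copies the data and leaves both columns behind, while `df.rename(...)` keeps exactly one, and the second branch also catches the `col` alias. A standalone sketch of the same idea (the helper is hypothetical, not spacr API):

    import pandas as pd

    # Hypothetical helper: normalize common well-metadata aliases so downstream
    # code can rely on a single 'column_name' column existing.
    def normalize_well_columns(df: pd.DataFrame) -> pd.DataFrame:
        aliases = {'column': 'column_name', 'col': 'column_name'}
        return df.rename(columns={k: v for k, v in aliases.items() if k in df.columns})

    df = pd.DataFrame({'plate_name': ['plate1'], 'row': ['r1'], 'col': ['c1']})
    print(normalize_well_columns(df).columns.tolist())
    # ['plate_name', 'row', 'column_name']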
@@ -724,13 +859,16 @@ def perform_regression(settings):
             df = pd.read_csv(score_data)
             df['plate_name'] = f'plate{i+1}'
             if 'column' in df.columns:
-                df['column_name'] = df['column']
+                df.rename(columns={'column': 'column_name'}, inplace=True)
+            if 'col' in df.columns:
+                df.rename(columns={'col': 'column_name'}, inplace=True)
+
             score_data_df = pd.concat([score_data_df, df])
+            display(score_data_df)
         print('Score data:', len(score_data_df))
     else:
         count_data_df = pd.read_csv(settings['count_data'])
         score_data_df = pd.read_csv(settings['score_data'])
-
     print(f"Dependent variable: {len(score_data_df)}")
     print(f"Independent variable: {len(count_data_df)}")
 
@@ -738,7 +876,8 @@ def perform_regression(settings):
         print(f'Columns in DataFrame:')
         for col in score_data_df.columns:
             print(col)
-        raise ValueError(f"Dependent variable {settings['dependent_variable']} not found in the DataFrame")
+        if not settings['dependent_variable'] == 'pathogen_nucleus_shortest_distance':
+            raise ValueError(f"Dependent variable {settings['dependent_variable']} not found in the DataFrame")
 
     if 'prediction_probability_class_1' in score_data_df.columns:
         if not settings['class_1_threshold'] is None:
@@ -804,6 +943,75 @@ def perform_regression(settings):
             return df, n_gene
         else:
             return df
+
+    def grna_metricks(df):
+        df[['plate', 'row', 'column']] = df['prc'].str.split('_', expand=True)
+
+        # --- 2) Compute gRNA-level well counts ---
+        # For each (grna, plate), count the number of unique prc values (wells)
+        grna_well_counts = (df.groupby(['grna', 'plate'])['prc'].nunique().reset_index(name='grna_well_count'))
+
+        # --- 3) Compute gene-level well counts ---
+        # For each (gene, plate), count the number of unique prc values
+        gene_well_counts = (df.groupby(['gene', 'plate'])['prc'].nunique().reset_index(name='gene_well_count'))
+
+        # --- 4) Merge these counts into a single DataFrame ---
+        # Because each grna is typically associated with one gene, bring them together.
+        # First, create a unique (grna, gene, plate) reference from the original df
+        unique_triplets = df[['grna', 'gene', 'plate']].drop_duplicates()
+
+        # Merge in the grna_well_count
+        merged_df = pd.merge(unique_triplets, grna_well_counts, on=['grna', 'plate'], how='left')
+
+        # Merge in the gene_well_count
+        merged_df = pd.merge(merged_df, gene_well_counts, on=['gene', 'plate'], how='left')
+
+        # Keep only the columns needed (the selection below drops 'gene'; retain it here if required)
+        final_grna_df = merged_df[['grna', 'plate', 'grna_well_count', 'gene_well_count']]
+
+        # --- 5) Compute gene_count per prc ---
+        # For each prc (well), how many distinct genes are there?
+        prc_gene_count_df = (df.groupby('prc')['gene'].nunique().reset_index(name='gene_count'))
+        prc_gene_count_df[['plate', 'row', 'column']] = prc_gene_count_df['prc'].str.split('_', expand=True)
+
+        return final_grna_df, prc_gene_count_df
+
+    def get_outlier_reference_values(df, outlier_col, return_col):
+        """
+        Detect outliers in 'outlier_col' of 'df' using a 1.5 × spread rule on the
+        5th and 95th percentiles, and return the 'return_col' values for those rows.
+
+        Parameters:
+        -----------
+        df : pd.DataFrame
+            Input DataFrame.
+        outlier_col : str
+            Column in which to check for outliers.
+        return_col : str
+            Column whose values to return for rows that are outliers in 'outlier_col'.
+
+        Returns:
+        --------
+        list
+            Unique values from 'return_col' for the outlier rows.
+        """
+        # Calculate the 5th percentile, 95th percentile, and their spread for outlier_col
+        Q1 = df[outlier_col].quantile(0.05)
+        Q3 = df[outlier_col].quantile(0.95)
+        IQR = Q3 - Q1
+
+        # Determine the outlier cutoffs
+        lower_bound = Q1 - 1.5 * IQR
+        upper_bound = Q3 + 1.5 * IQR
+
+        # Create a mask for outliers
+        outlier_mask = (df[outlier_col] < lower_bound) | (df[outlier_col] > upper_bound)
+
+        outliers = df.loc[outlier_mask, return_col]
+
+        outliers_ls = outliers.unique().tolist()
+
+        return outliers_ls
 
     settings = get_perform_regression_default_settings(settings)
     count_data_df, score_data_df = _perform_regression_read_data(settings)
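Although get_outlier_reference_values describes itself in terms of the classic 1.5 × IQR rule, its cutoffs are built from the 5th and 95th percentiles, which makes the band wider and the rule more conservative than a true quartile-based fence. A standalone illustration on toy wells-per-gRNA counts (data assumed for the example):

    import pandas as pd

    counts = pd.DataFrame({
        'grna': [f'g{i}' for i in range(60)],
        'grna_well_count': [9, 10, 11, 12] * 14 + [10, 11, 12, 500],
    })
    q_low = counts['grna_well_count'].quantile(0.05)   # 5th percentile, not Q1
    q_high = counts['grna_well_count'].quantile(0.95)  # 95th percentile, not Q3
    spread = q_high - q_low
    mask = ((counts['grna_well_count'] < q_low - 1.5 * spread) |
            (counts['grna_well_count'] > q_high + 1.5 * spread))
    print(counts.loc[mask, 'grna'].tolist())  # ['g59']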
@@ -839,30 +1047,107 @@ def perform_regression(settings):
 
     if settings['min_cell_count'] is None:
         settings['min_cell_count'] = minimum_cell_simulation(settings, tolerance=settings['tolerance'])
+
     print(f"Minimum cell count: {settings['min_cell_count']}")
+    print(f"Dependent variable after minimum cell count filter: {len(score_data_df)}")
+    display(score_data_df)
 
     orig_dv = settings['dependent_variable']
 
     dependent_df, dependent_variable = process_scores(score_data_df, settings['dependent_variable'], settings['plate'], settings['min_cell_count'], settings['agg_type'], settings['transform'])
     print(f"Dependent variable after process_scores: {len(dependent_df)}")
-
+    display(dependent_df)
+
     if settings['fraction_threshold'] is None:
         settings['fraction_threshold'] = graph_sequencing_stats(settings)
 
     independent_df = process_reads(count_data_df, settings['fraction_threshold'], settings['plate'], filter_column=filter_column, filter_value=filter_value)
     independent_df, n_grna, n_gene = _count_variable_instances(independent_df, column_1='grna', column_2='gene')
-
+
     print(f"Independent variable after process_reads: {len(independent_df)}")
 
     merged_df = pd.merge(independent_df, dependent_df, on='prc')
+
+    display(independent_df)
+    display(dependent_df)
 
-    os.makedirs(res_folder, exist_ok=True)
-    data_path = os.path.join(res_folder, 'regression_data.csv')
-    merged_df.to_csv(data_path, index=False)
-    print(f"Saved regression data to {data_path}")
-
-    merged_df[['plate', 'row_name', 'column']] = merged_df['prc'].str.split('_', expand=True)
+    display(merged_df)
+
 
+    merged_df[['plate', 'row_name', 'column']] = merged_df['prc'].str.split('_', expand=True)
+
+    try:
+        os.makedirs(res_folder, exist_ok=True)
+        data_path = os.path.join(res_folder, 'regression_data.csv')
+        merged_df.to_csv(data_path, index=False)
+        print(f"Saved regression data to {data_path}")
+
+        cell_settings = {'src': data_path,
+                         'graph_name': 'cell_count',
+                         'data_column': ['cell_count'],
+                         'grouping_column': 'plate',
+                         'graph_type': 'jitter_bar',
+                         'theme': 'bright',
+                         'save': True,
+                         'y_lim': [None, None],
+                         'log_y': False,
+                         'log_x': False,
+                         'representation': 'well',
+                         'verbose': False}
+
+        _, _ = plot_data_from_csv(settings=cell_settings)
+
+        final_grna_df, prc_gene_count_df = grna_metricks(merged_df)
+
+        if settings['outlier_detection']:
+            outliers_grna = get_outlier_reference_values(final_grna_df, outlier_col='grna_well_count', return_col='grna')
+            if len(outliers_grna) > 0:
+                merged_df = merged_df[~merged_df['grna'].isin(outliers_grna)]
+                final_grna_df, prc_gene_count_df = grna_metricks(merged_df)
+                merged_df.to_csv(data_path, index=False)
+                print(f"Saved regression data to {data_path}")
+
+        grna_data_path = os.path.join(res_folder, 'grna_well.csv')
+        final_grna_df.to_csv(grna_data_path, index=False)
+        print(f"Saved grna per well data to {grna_data_path}")
+
+        wells_per_gene_settings = {'src': grna_data_path,
+                                   'graph_name': 'wells_per_gene',
+                                   'data_column': ['grna_well_count'],
+                                   'grouping_column': 'plate',
+                                   'graph_type': 'jitter_bar',
+                                   'theme': 'bright',
+                                   'save': True,
+                                   'y_lim': [None, None],
+                                   'log_y': False,
+                                   'log_x': False,
+                                   'representation': 'object',
+                                   'verbose': True}
+
+        _, _ = plot_data_from_csv(settings=wells_per_gene_settings)
+
+        grna_well_data_path = os.path.join(res_folder, 'well_grna.csv')
+        prc_gene_count_df.to_csv(grna_well_data_path, index=False)
+        print(f"Saved well per grna data to {grna_well_data_path}")
+
+        grna_per_well_settings = {'src': grna_well_data_path,
+                                  'graph_name': 'gene_per_well',
+                                  'data_column': ['gene_count'],
+                                  'grouping_column': 'plate',
+                                  'graph_type': 'jitter_bar',
+                                  'theme': 'bright',
+                                  'save': True,
+                                  'y_lim': [None, None],
+                                  'log_y': False,
+                                  'log_x': False,
+                                  'representation': 'well',
+                                  'verbose': False}
+
+        _, _ = plot_data_from_csv(settings=grna_per_well_settings)
+
+    except Exception as e:
+        print(e)
+
     _ = plot_plates(merged_df, variable=orig_dv, grouping='mean', min_max='allq', cmap='viridis', min_count=None, dst=res_folder)
 
     model, coef_df, regression_type = regression(merged_df, csv_path, dependent_variable, settings['regression_type'], settings['alpha'], settings['random_row_column_effects'], nc=settings['negative_control'], pc=settings['positive_control'], controls=settings['controls'], dst=res_folder, cov_type=settings['cov_type'])
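The three plotting calls above pass near-identical settings dictionaries to plot_data_from_csv, differing only in the source file, graph name, data column, and representation. A hypothetical wrapper (not part of spacr) that would express the shared keys once:

    # Hypothetical convenience wrapper: build the repeated jitter-bar settings
    # dicts from the fields that actually vary between the three calls above.
    def jitter_bar_settings(src, graph_name, data_column, representation, verbose=False):
        return {'src': src,
                'graph_name': graph_name,
                'data_column': data_column,
                'grouping_column': 'plate',
                'graph_type': 'jitter_bar',
                'theme': 'bright',
                'save': True,
                'y_lim': [None, None],
                'log_y': False,
                'log_x': False,
                'representation': representation,
                'verbose': verbose}

    # e.g. _, _ = plot_data_from_csv(settings=jitter_bar_settings(
    #     data_path, 'cell_count', ['cell_count'], 'well'))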
@@ -940,6 +1225,8 @@ def perform_regression(settings):
     base_dir = os.path.dirname(os.path.abspath(__file__))
     metadata_path = os.path.join(base_dir, 'resources', 'data', 'lopit.csv')
 
+
+
     if settings['volcano'] == 'all':
         print('all')
         gene_list = custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location', point_size=600, figsize=20, threshold=reg_threshold, save_path=volcano_path, x_lim=settings['x_lim'], y_lims=settings['y_lims'])
@@ -952,6 +1239,7 @@ def perform_regression(settings):
         print('grna')
         gene_list = custom_volcano_plot(data_path_grna, metadata_path, metadata_column='tagm_location', point_size=600, figsize=20, threshold=reg_threshold, save_path=volcano_path, x_lim=settings['x_lim'], y_lims=settings['y_lims'])
     display(gene_list)
+
     phenotype_plot = os.path.join(res_folder, 'phenotype_plot.pdf')
     transcription_heatmap = os.path.join(res_folder, 'transcription_heatmap.pdf')
     data_GT1 = pd.read_csv(settings['metadata_files'][1], low_memory=False)
@@ -1006,6 +1294,11 @@ def process_reads(csv_path, fraction_threshold, plate, filter_column=None, filte
         csv_df['plate'] = plate
     else:
         csv_df['plate'] = 'plate1'
+
+    if 'prcfo' in csv_df.columns:
+        #csv_df = csv_df.loc[:, ~csv_df.columns.duplicated()].copy()
+        csv_df[['plate_name', 'row_name', 'column_name', 'field_name', 'object_name']] = csv_df['prcfo'].str.split('_', expand=True)
+        csv_df['prc'] = csv_df['plate_name'].astype(str) + '_' + csv_df['row_name'].astype(str) + '_' + csv_df['column_name'].astype(str)
 
     if isinstance(filter_column, str):
         filter_column = [filter_column]
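The new branch assumes a prcfo key of the form plate_row_column_field_object and rebuilds the per-well prc key from its first three parts. A standalone illustration (the toy identifiers are assumed; the split requires exactly five underscore-separated fields):

    import pandas as pd

    df = pd.DataFrame({'prcfo': ['plate1_r1_c3_f2_o17', 'plate1_r2_c4_f1_o5']})
    parts = ['plate_name', 'row_name', 'column_name', 'field_name', 'object_name']
    df[parts] = df['prcfo'].str.split('_', expand=True)
    df['prc'] = df['plate_name'] + '_' + df['row_name'] + '_' + df['column_name']
    print(df['prc'].tolist())  # ['plate1_r1_c3', 'plate1_r2_c4']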
@@ -1085,32 +1378,44 @@ def clean_controls(df,values, column):
     return df
 
 def process_scores(df, dependent_variable, plate, min_cell_count=25, agg_type='mean', transform=None, regression_type='ols'):
-
-    if 'plate_name' in df.columns:
-        df.drop(columns=['plate'], inplace=True)
-        df = df.rename(columns={'plate_name': 'plate'})
-
-    if 'row' in df.columns:
-        df = df.rename(columns={'row': 'row_name'})
-    if 'col' in df.columns:
-        df = df.rename(columns={'row': 'column_name'})
+    from .utils import calculate_shortest_distance
+    df = df.reset_index(drop=True)
+
+    if 'prcfo' in df.columns:
+        df = df.loc[:, ~df.columns.duplicated()].copy()
+        if not all(col in df.columns for col in ['plate_name', 'row_name', 'column_name']):
+            df[['plate_name', 'row_name', 'column_name', 'field_name', 'object_name']] = df['prcfo'].str.split('_', expand=True)
+        if all(col in df.columns for col in ['plate_name', 'row_name', 'column_name']):
+            df['prc'] = df['plate_name'].astype(str) + '_' + df['row_name'].astype(str) + '_' + df['column_name'].astype(str)
+    else:
+        if 'plate_name' in df.columns:
+            df.drop(columns=['plate'], inplace=True)
+            #df = df.rename(columns={'plate_name': 'plate'})
+
+        if 'plate' in df.columns:
+            df['plate_name'] = df['plate']
+
+        if plate is not None:
+            df['plate_name'] = plate
 
-    if plate is not None:
-        df['plate'] = plate
-
-    if 'column_name' not in df.columns:
-        df['column_name'] = df['column']
-
-    df['prc'] = df['plate'].astype(str) + '_' + df['row_name'].astype(str) + '_' + df['column_name'].astype(str)
-
-    display(df)
-
+    if 'row' in df.columns:
+        df = df.rename(columns={'row': 'row_name'})
+
+    if 'col' in df.columns:
+        df = df.rename(columns={'col': 'column_name'})
+
+    if 'column' in df.columns:
+        df = df.rename(columns={'column': 'column_name'})
 
-    df = df[['prc', dependent_variable]]
+    df = df.loc[:, ~df.columns.duplicated()].copy()
+    df['prc'] = df['plate_name'].astype(str) + '_' + df['row_name'].astype(str) + '_' + df['column_name'].astype(str)
 
+    df = df[['prc', dependent_variable]]
     # Group by prc and calculate the mean and count of the dependent_variable
     grouped = df.groupby('prc')[dependent_variable]
 
+    display(grouped)
+
     if regression_type != 'poisson':
 
         print(f'Using agg_type: {agg_type}')
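After the metadata normalization, process_scores reduces the per-cell table to prc plus the dependent variable and aggregates per well, with min_cell_count guarding against poorly populated wells. A minimal sketch of that aggregate-then-filter step (column names assumed):

    import pandas as pd

    scores = pd.DataFrame({'prc': ['p1_r1_c1'] * 30 + ['p1_r1_c2'] * 5,
                           'score': [0.8] * 30 + [0.2] * 5})
    min_cell_count, agg_type = 25, 'mean'
    grouped = scores.groupby('prc')['score']
    well_df = grouped.agg(value=agg_type, count='size').reset_index()
    well_df = well_df[well_df['count'] >= min_cell_count]  # drops p1_r1_c2
    print(well_df)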
@@ -1163,7 +1468,7 @@ def generate_ml_scores(settings):
 
     from .io import _read_and_merge_data, _read_db
     from .plot import plot_plates
-    from .utils import get_ml_results_paths, add_column_to_database
+    from .utils import get_ml_results_paths, add_column_to_database, calculate_shortest_distance
     from .settings import set_default_analyze_screen
 
     settings = set_default_analyze_screen(settings)
@@ -1192,6 +1497,11 @@ def generate_ml_scores(settings):
                                           pathogen_limit=settings['pathogen_limit'])
         df = pd.concat([df, dft])
 
+    try:
+        df = calculate_shortest_distance(df, 'pathogen', 'nucleus')
+    except Exception as e:
+        print(e)
+
     if settings['annotation_column'] is not None:
 
         settings['location_column'] = settings['annotation_column']
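calculate_shortest_distance(df, 'pathogen', 'nucleus') comes from spacr.utils and its body is not shown in this diff, so the following is only a guess at the kind of computation involved: a per-row Euclidean distance between assumed centroid columns. Column names and logic are assumptions, not spacr's implementation:

    import numpy as np
    import pandas as pd

    # Sketch only: assumes per-object centroid columns '<name>_centroid_0/1'.
    def centroid_distance(df, a='pathogen', b='nucleus'):
        dx = df[f'{a}_centroid_0'] - df[f'{b}_centroid_0']
        dy = df[f'{a}_centroid_1'] - df[f'{b}_centroid_1']
        return np.hypot(dx, dy)

    df = pd.DataFrame({'pathogen_centroid_0': [10.0], 'pathogen_centroid_1': [4.0],
                       'nucleus_centroid_0': [7.0], 'nucleus_centroid_1': [0.0]})
    print(centroid_distance(df).tolist())  # [5.0]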
@@ -1337,9 +1647,16 @@ def ml_analysis(df, channel_of_interest=3, location_column='column_name', positi
     df = pd.concat([df, df_metadata[location_column]], axis=1)
 
     # Subset the dataframe based on specified column values
-    df1 = df[df[location_column] == negative_control].copy()
-    df2 = df[df[location_column] == positive_control].copy()
-
+    if isinstance(negative_control, str):
+        df1 = df[df[location_column] == negative_control].copy()
+    elif isinstance(negative_control, list):
+        df1 = df[df[location_column].isin(negative_control)].copy()
+
+    if isinstance(positive_control, str):
+        df2 = df[df[location_column] == positive_control].copy()
+    elif isinstance(positive_control, list):
+        df2 = df[df[location_column].isin(positive_control)].copy()
+
     # Create target variable
     df1['target'] = 0  # Negative control
     df2['target'] = 1  # Positive control
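The str-or-list branches above let the controls be either a single location value or several. An equivalent, more compact pattern (a design alternative, not the spacr code) normalizes the spec to a list and always uses .isin():

    import pandas as pd

    def subset_controls(df, location_column, control):
        values = [control] if isinstance(control, str) else list(control)
        return df[df[location_column].isin(values)].copy()

    df = pd.DataFrame({'column_name': ['c1', 'c2', 'c3'], 'x': [1, 2, 3]})
    print(len(subset_controls(df, 'column_name', 'c1')))          # 1
    print(len(subset_controls(df, 'column_name', ['c1', 'c3'])))  # 2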
@@ -1442,13 +1759,13 @@ def ml_analysis(df, channel_of_interest=3, location_column='column_name', positi
         for i in range(prediction_probabilities.shape[1]):
             df[f'prediction_probability_class_{i}'] = prediction_probabilities[:, i]
 
-        if verbose:
-            print("\nFinal Classification Report on Full Dataset:")
-            print(classification_report(y, all_predictions))
+        #if verbose:
+        #    print("\nFinal Classification Report on Full Dataset:")
+        #    print(classification_report(y, all_predictions))
 
         # Generate metrics DataFrame
-        final_report_dict = classification_report(y, all_predictions, output_dict=True)
-        metrics_df = pd.DataFrame(final_report_dict).transpose()
+        #final_report_dict = classification_report(y, all_predictions, output_dict=True)
+        #metrics_df = pd.DataFrame(final_report_dict).transpose()
 
     else:
         model.fit(X_train, y_train)
@@ -1572,8 +1889,14 @@ def _calculate_similarity(df, features, col_to_compare, val1, val2):
         pandas.DataFrame: DataFrame with similarity scores.
     """
     # Separate positive and negative control wells
-    pos_control = df[df[col_to_compare] == val1][features].mean()
-    neg_control = df[df[col_to_compare] == val2][features].mean()
+    if isinstance(val1, str):
+        pos_control = df[df[col_to_compare] == val1][features].mean()
+    elif isinstance(val1, list):
+        pos_control = df[df[col_to_compare].isin(val1)][features].mean()
+    if isinstance(val2, str):
+        neg_control = df[df[col_to_compare] == val2][features].mean()
+    elif isinstance(val2, list):
+        neg_control = df[df[col_to_compare].isin(val2)][features].mean()
 
     # Standardize features for Mahalanobis distance
     scaler = StandardScaler()
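_calculate_similarity goes on to standardize the features before computing Mahalanobis distances to the control profiles. A standalone sketch of that distance on toy data (the data and variable names are illustrative):

    import numpy as np
    from scipy.spatial.distance import mahalanobis
    from sklearn.preprocessing import StandardScaler

    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 3))                     # toy feature matrix
    X_std = StandardScaler().fit_transform(X)         # standardize, as above
    vi = np.linalg.inv(np.cov(X_std, rowvar=False))   # inverse covariance
    ref = X_std.mean(axis=0)                          # e.g. a control mean profile
    d = np.array([mahalanobis(x, ref, vi) for x in X_std])
    print(d[:3].round(3))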