spacr 0.0.82__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/plot.py CHANGED
@@ -1110,7 +1110,7 @@ def _imshow(img, labels, nrow=20, color='white', fontsize=12):
             idx = i * n_col + j
             if idx < n_images:
                 canvas[i * img_height:(i + 1) * img_height, j * img_width:(j + 1) * img_width] = np.transpose(img[idx], (1, 2, 0))
-    plt.figure(figsize=(50, 50))
+    fig = plt.figure(figsize=(50, 50))
     plt.imshow(canvas)
     plt.axis("off")
     for i, label in enumerate(labels):
@@ -1119,7 +1119,7 @@ def _imshow(img, labels, nrow=20, color='white', fontsize=12):
         x = col * img_width + 2
         y = row * img_height + 15
         plt.text(x, y, label, color=color, fontsize=fontsize, fontweight='bold')
-    plt.show()
+    return fig
 
 def _plot_histograms_and_stats(df):
     conditions = df['condition'].unique()
@@ -1218,8 +1218,7 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_co
     df = df[df['count'] >= min_count]
 
     # Explicitly set observed=True to avoid FutureWarning
-    grouped = df.groupby(['row', 'col'], observed=True)
-
+    grouped = df.groupby(['row', 'col'], observed=True) # Group by row and column
 
     if grouping == 'mean':
         plate = grouped[variable].mean().reset_index()
@@ -1245,7 +1244,7 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_co
 
     return plate_map, min_max
 
-def _plot_plates(df, variable, grouping, min_max, cmap, min_count=0):
+def plot_plates(df, variable, grouping, min_max, cmap, min_count=0, verbose=True):
     plates = df['prc'].str.split('_', expand=True)[0].unique()
     n_rows, n_cols = (len(plates) + 3) // 4, 4
     fig, ax = plt.subplots(n_rows, n_cols, figsize=(40, 5 * n_rows))
@@ -1260,7 +1259,8 @@ def _plot_plates(df, variable, grouping, min_max, cmap, min_count=0):
         fig.delaxes(ax[i])
 
     plt.subplots_adjust(wspace=0.1, hspace=0.4)
-    plt.show()
+    if verbose:
+        plt.show()
     return fig
 
 def print_mask_and_flows(stack, mask, flows, overlay=False):
@@ -1470,3 +1470,99 @@ def plot_object_outlines(src, objects=['nucleus','cell','pathogen'], channels=[0
         overlay=True,
         max_nr=10,
         randomize=True)
+
+def volcano_plot(coef_df, filename='volcano_plot.pdf'):
+    # Create the volcano plot
+    plt.figure(figsize=(10, 6))
+    sns.scatterplot(
+        data=coef_df,
+        x='coefficient',
+        y='-log10(p_value)',
+        hue='highlight',
+        palette={True: 'red', False: 'blue'}
+    )
+    plt.title('Volcano Plot of Coefficients')
+    plt.xlabel('Coefficient')
+    plt.ylabel('-log10(p-value)')
+    plt.axhline(y=-np.log10(0.05), color='red', linestyle='--')
+    plt.legend().remove()
+    plt.savefig(filename, format='pdf')
+    print(f'Saved Volcano plot: {filename}')
+    plt.show()
+
+def plot_histogram(df, dependent_variable):
+    # Plot histogram of the dependent variable
+    plt.figure(figsize=(10, 6))
+    sns.histplot(df[dependent_variable], kde=True)
+    plt.title(f'Histogram of {dependent_variable}')
+    plt.xlabel(dependent_variable)
+    plt.ylabel('Frequency')
+    plt.show()
+
+def plot_lorenz_curves(csv_files, remove_keys=['TGGT1_220950_1', 'TGGT1_233460_4']):
+
+    def lorenz_curve(data):
+        """Calculate Lorenz curve."""
+        sorted_data = np.sort(data)
+        cumulative_data = np.cumsum(sorted_data)
+        lorenz_curve = cumulative_data / cumulative_data[-1]
+        lorenz_curve = np.insert(lorenz_curve, 0, 0)
+        return lorenz_curve
+
+    combined_data = []
+
+    plt.figure(figsize=(10, 6))
+
+    for idx, csv_file in enumerate(csv_files):
+        if idx == 1:
+            save_fldr = os.path.dirname(csv_file)
+            save_path = os.path.join(save_fldr, 'lorenz_curve.pdf')
+
+        df = pd.read_csv(csv_file)
+        for remove in remove_keys:
+            df = df[df['key'] != remove]
+
+        values = df['value'].values
+        combined_data.extend(values)
+
+        lorenz = lorenz_curve(values)
+        name = os.path.basename(csv_file)[:3]
+        plt.plot(np.linspace(0, 1, len(lorenz)), lorenz, label=name)
+
+    # Plot combined Lorenz curve
+    combined_lorenz = lorenz_curve(np.array(combined_data))
+    plt.plot(np.linspace(0, 1, len(combined_lorenz)), combined_lorenz, label="Combined Lorenz Curve", linestyle='--', color='black')
+
+    plt.title('Lorenz Curves')
+    plt.xlabel('Cumulative Share of Individuals')
+    plt.ylabel('Cumulative Share of Value')
+    plt.legend()
+    plt.grid(False)
+    plt.savefig(save_path)
+    plt.show()
+
+def plot_permutation(permutation_df):
+    num_features = len(permutation_df)
+    fig_height = max(8, num_features * 0.3)  # Set a minimum height of 8 and adjust height based on number of features
+    fig_width = 10  # Width can be fixed or adjusted similarly
+    font_size = max(10, 12 - num_features * 0.2)  # Adjust font size dynamically
+
+    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
+    ax.barh(permutation_df['feature'], permutation_df['importance_mean'], xerr=permutation_df['importance_std'], color="teal", align="center", alpha=0.6)
+    ax.set_xlabel('Permutation Importance', fontsize=font_size)
+    ax.tick_params(axis='both', which='major', labelsize=font_size)
+    plt.tight_layout()
+    return fig
+
+def plot_feature_importance(feature_importance_df):
+    num_features = len(feature_importance_df)
+    fig_height = max(8, num_features * 0.3)  # Set a minimum height of 8 and adjust height based on number of features
+    fig_width = 10  # Width can be fixed or adjusted similarly
+    font_size = max(10, 12 - num_features * 0.2)  # Adjust font size dynamically
+
+    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
+    ax.barh(feature_importance_df['feature'], feature_importance_df['importance'], color="blue", align="center", alpha=0.6)
+    ax.set_xlabel('Feature Importance', fontsize=font_size)
+    ax.tick_params(axis='both', which='major', labelsize=font_size)
+    plt.tight_layout()
+    return fig
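
Note on the plot.py changes above: the reworked helpers now build a figure and return it (displaying only when verbose is requested) instead of calling plt.show() unconditionally, so callers can save or embed the output headlessly. A minimal, self-contained sketch of that pattern follows; plot_grid and its arguments are illustrative stand-ins, not spacr functions.

import matplotlib
matplotlib.use("Agg")  # assumption: a non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

def plot_grid(images, verbose=False):
    # Build the canvas, but let the caller decide whether to display it.
    fig, axes = plt.subplots(1, len(images), figsize=(3 * len(images), 3))
    for ax, img in zip(np.atleast_1d(axes), images):
        ax.imshow(img, cmap="gray")
        ax.axis("off")
    if verbose:
        plt.show()  # mirrors the new verbose flag on plot_plates
    return fig      # mirrors _imshow/plot_plates now returning the figure

fig = plot_grid([np.random.rand(16, 16) for _ in range(3)])
fig.savefig("grid.png", dpi=150)  # persist the figure without ever opening a window
plt.close(fig)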
spacr/sequencing.py CHANGED
@@ -10,6 +10,7 @@ import statsmodels.api as sm
 from statsmodels.regression.mixed_linear_model import MixedLM
 from statsmodels.stats.outliers_influence import variance_inflation_factor
 from scipy.stats import gmean
+from scipy import stats
 from difflib import SequenceMatcher
 from collections import Counter
 from IPython.display import display
@@ -305,13 +306,10 @@ def analyze_reads(settings):
     qc_df = pd.DataFrame([qc])
     qc_df.to_csv(qc_file_path, index=False)
 
-    settings.setdefault('upstream', 'CTTCTGGTAAATGGGGATGTCAAGTT')
-    settings.setdefault('downstream', 'GTTTAAGAGCTATGCTGGAAACAGCAG') #This is the reverce compliment of the column primer starting from the end #TGCTGTTTAAGAGCTATGCTGGAAACAGCA
-    settings.setdefault('barecode_length_1', 8)
-    settings.setdefault('barecode_length_2', 7)
-    settings.setdefault('chunk_size', 1000000)
-    settings.setdefault('test', False)
-
+    from .utils import get_analyze_reads_default_settings
+
+    settings = get_analyze_reads_default_settings(settings)
+
     samples_dict = parse_gz_files(settings['src'])
     combine_reads(samples_dict, settings['src'], settings['chunk_size'], settings['barecode_length_1'], settings['barecode_length_2'], settings['upstream'], settings['downstream'])
 
@@ -477,17 +475,10 @@ def map_barcodes(h5_file_path, settings={}):
         #print(f"Max count for {nc} in other columns: {max_count_c3}")
 
         return filtered_df
+
+    from .settings import get_map_barcodes_default_settings
 
-    settings.setdefault('grna', '/home/carruthers/Documents/grna_barcodes.csv')
-    settings.setdefault('barcodes', '/home/carruthers/Documents/SCREEN_BARCODES.csv')
-    settings.setdefault('plate_dict', {'EO1': 'plate1', 'EO2': 'plate2', 'EO3': 'plate3', 'EO4': 'plate4', 'EO5': 'plate5', 'EO6': 'plate6', 'EO7': 'plate7', 'EO8': 'plate8'})
-    settings.setdefault('test', False)
-    settings.setdefault('verbose', True)
-
-    settings.setdefault('pc', 'TGGT1_220950_1')
-    settings.setdefault('pc_loc', 'c2')
-    settings.setdefault('nc', 'TGGT1_233460_4')
-    settings.setdefault('nc_loc', 'c1')
+    settings = get_map_barcodes_default_settings(settings)
 
     fldr = os.path.splitext(h5_file_path)[0]
     file_name = os.path.basename(fldr)
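
The two hunks above replace per-function settings.setdefault(...) blocks with shared helpers (get_analyze_reads_default_settings in spacr.utils, get_map_barcodes_default_settings in spacr.settings). A rough sketch of how such a defaults helper is typically written is shown below; the default values are copied from the removed lines, but the body is an assumption for illustration, not the actual spacr implementation.

def get_analyze_reads_default_settings(settings):
    # Centralized defaults; any key the caller already set is left untouched.
    defaults = {
        'upstream': 'CTTCTGGTAAATGGGGATGTCAAGTT',
        'downstream': 'GTTTAAGAGCTATGCTGGAAACAGCAG',
        'barecode_length_1': 8,
        'barecode_length_2': 7,
        'chunk_size': 1000000,
        'test': False,
    }
    for key, value in defaults.items():
        settings.setdefault(key, value)
    return settings

settings = get_analyze_reads_default_settings({'src': '/path/to/reads', 'chunk_size': 500000})
assert settings['chunk_size'] == 500000 and settings['test'] is False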
@@ -1393,16 +1384,6 @@ def generate_fraction_map(df, gene_column, min_=10, plates=['p1','p2','p3','p4']
     independent_variables.index.name = 'prc'
     independent_variables = independent_variables.loc[:, (independent_variables.sum() != 0)]
     return independent_variables
-
-
-def plot_histogram(df, dependent_variable):
-    # Plot histogram of the dependent variable
-    plt.figure(figsize=(10, 6))
-    sns.histplot(df[dependent_variable], kde=True)
-    plt.title(f'Histogram of {dependent_variable}')
-    plt.xlabel(dependent_variable)
-    plt.ylabel('Frequency')
-    plt.show()
 
 def precess_reads(csv_path, fraction_threshold, plate):
     # Read the CSV file into a DataFrame
@@ -1475,32 +1456,42 @@ def check_normality(data, variable_name, verbose=False):
         print(f"The data for {variable_name} is not normally distributed.")
         return False
 
-def process_scores(df, dependent_variable, plate, min_cell_count=25, agg_type='mean', transform=None):
+def process_scores(df, dependent_variable, plate, min_cell_count=25, agg_type='mean', transform=None, regression_type='ols'):
 
     if plate is not None:
         df['plate'] = plate
 
+    if 'col' not in df.columns:
+        df['col'] = df['column']
+
     df['prc'] = df['plate'] + '_' + df['row'] + '_' + df['col']
     df = df[['prc', dependent_variable]]
 
     # Group by prc and calculate the mean and count of the dependent_variable
     grouped = df.groupby('prc')[dependent_variable]
-
-    print(f'Using agg_type: {agg_type}')
-    if agg_type == 'median':
-        dependent_df = grouped.median().reset_index()
-    elif agg_type == 'mean':
-        dependent_df = grouped.mean().reset_index()
-    elif agg_type == 'quantile':
-        dependent_df = grouped.quantile(0.75).reset_index()
-    elif agg_type == None:
-        dependent_df = df.reset_index()
-        if 'prcfo' in dependent_df.columns:
-            dependent_df = dependent_df.drop(columns=['prcfo'])
-
-    else:
-        raise ValueError(f"Unsupported aggregation type {agg_type}")
-
+
+    if regression_type != 'poisson':
+
+        print(f'Using agg_type: {agg_type}')
+
+        if agg_type == 'median':
+            dependent_df = grouped.median().reset_index()
+        elif agg_type == 'mean':
+            dependent_df = grouped.mean().reset_index()
+        elif agg_type == 'quantile':
+            dependent_df = grouped.quantile(0.75).reset_index()
+        elif agg_type == None:
+            dependent_df = df.reset_index()
+            if 'prcfo' in dependent_df.columns:
+                dependent_df = dependent_df.drop(columns=['prcfo'])
+        else:
+            raise ValueError(f"Unsupported aggregation type {agg_type}")
+
+    if regression_type == 'poisson':
+        agg_type = 'count'
+        print(f'Using agg_type: {agg_type} for poisson regression')
+        dependent_df = grouped.sum().reset_index()
+
     # Calculate cell_count for all cases
     cell_count = grouped.size().reset_index(name='cell_count')
 
@@ -1516,7 +1507,7 @@ def process_scores(df, dependent_variable, plate, min_cell_count=25, agg_type='m
     if not transform is None:
         transformer = apply_transformation(dependent_df[dependent_variable], transform=transform)
         transformed_var = f'{transform}_{dependent_variable}'
-        df[transformed_var] = transformer.fit_transform(dependent_df[[dependent_variable]])
+        dependent_df[transformed_var] = transformer.fit_transform(dependent_df[[dependent_variable]])
         dependent_variable = transformed_var
         is_normal = check_normality(dependent_df[transformed_var], transformed_var)
 
@@ -1557,6 +1548,7 @@ def regression_model(X, y, regression_type='ols', groups=None, alpha=1.0, remove
         model = sm.GLS(y, X).fit()
 
     elif regression_type == 'wls':
+        weights = 1 / np.sqrt(X.iloc[:, 1])
        model = sm.WLS(y, X, weights=weights).fit()
 
     elif regression_type == 'rlm':
@@ -1589,7 +1581,7 @@ def regression_model(X, y, regression_type='ols', groups=None, alpha=1.0, remove
         model = sm.Probit(y, X).fit()
 
     elif regression_type == 'poisson':
-        model_poisson = sm.Poisson(y, X).fit()
+        model = sm.Poisson(y, X).fit()
 
     elif regression_type == 'lasso':
         model = Lasso(alpha=alpha).fit(X, y)
@@ -1611,25 +1603,6 @@ def regression_model(X, y, regression_type='ols', groups=None, alpha=1.0, remove
 
     return model
 
-def volcano_plot(coef_df, filename='volcano_plot.pdf'):
-    # Create the volcano plot
-    plt.figure(figsize=(10, 6))
-    sns.scatterplot(
-        data=coef_df,
-        x='coefficient',
-        y='-log10(p_value)',
-        hue='highlight',
-        palette={True: 'red', False: 'blue'}
-    )
-    plt.title('Volcano Plot of Coefficients')
-    plt.xlabel('Coefficient')
-    plt.ylabel('-log10(p-value)')
-    plt.axhline(y=-np.log10(0.05), color='red', linestyle='--')
-    plt.legend().remove()
-    plt.savefig(filename, format='pdf')
-    print(f'Saved Volcano plot: {filename}')
-    plt.show()
-
 def clean_controls(df,pc,nc,other):
     if 'col' in df.columns:
         df['column'] = df['col']
@@ -1642,20 +1615,60 @@ def clean_controls(df,pc,nc,other):
     print(f'Removed data from {nc, pc, other}')
     return df
 
+# Remove outliers by capping values at 1st and 99th percentiles for numerical columns only
+def remove_outliers(df, low=0.01, high=0.99):
+    numerical_cols = df.select_dtypes(include=[np.number]).columns
+    quantiles = df[numerical_cols].quantile([low, high])
+    for col in numerical_cols:
+        df[col] = np.clip(df[col], quantiles.loc[low, col], quantiles.loc[high, col])
+    return df
+
+def calculate_p_values(X, y, model):
+    # Predict y values
+    y_pred = model.predict(X)
+
+    # Calculate residuals
+    residuals = y - y_pred
+
+    # Calculate the standard error of the residuals
+    dof = X.shape[0] - X.shape[1] - 1
+    residual_std_error = np.sqrt(np.sum(residuals ** 2) / dof)
+
+    # Calculate the standard error of the coefficients
+    X_design = np.hstack((np.ones((X.shape[0], 1)), X))  # Add intercept
+
+    # Use pseudoinverse instead of inverse to handle singular matrices
+    coef_var_covar = residual_std_error ** 2 * np.linalg.pinv(X_design.T @ X_design)
+    coef_standard_errors = np.sqrt(np.diag(coef_var_covar))
+
+    # Calculate t-statistics
+    t_stats = model.coef_ / coef_standard_errors[1:]  # Skip intercept error
+
+    # Calculate p-values
+    p_values = [2 * (1 - stats.t.cdf(np.abs(t), dof)) for t in t_stats]
+
+    return np.array(p_values)  # Ensure p_values is a 1-dimensional array
+
 def regression(df, csv_path, dependent_variable='predictions', regression_type=None, alpha=1.0, remove_row_column_effect=False):
 
+    from .plot import volcano_plot, plot_histogram
+
     volcano_filename = os.path.splitext(os.path.basename(csv_path))[0] + '_volcano_plot.pdf'
     volcano_filename = regression_type+'_'+volcano_filename
     if regression_type == 'quantile':
         volcano_filename = str(alpha)+'_'+volcano_filename
     volcano_path=os.path.join(os.path.dirname(csv_path), volcano_filename)
 
+    is_normal = check_normality(df[dependent_variable], dependent_variable)
+
    if regression_type is None:
        if is_normal:
            regression_type = 'ols'
        else:
            regression_type = 'glm'
 
+    #df = remove_outliers(df)
+
     if remove_row_column_effect:
 
         ## 1. Fit the initial model with row and column to estimate their effects
@@ -1714,7 +1727,7 @@ def regression(df, csv_path, dependent_variable='predictions', regression_type=N
     model = regression_model(X, y, regression_type=regression_type, groups=groups, alpha=alpha, remove_row_column_effect=remove_row_column_effect)
 
     # Get the model coefficients and p-values
-    if regression_type in ['ols','gls','wls','rlm','glm','mixed','quantile','logit','probit','poisson','lasso','ridge']:
+    if regression_type in ['ols','gls','wls','rlm','glm','mixed','quantile','logit','probit','poisson']:
         coefs = model.params
         p_values = model.pvalues
 
@@ -1723,6 +1736,18 @@ def regression(df, csv_path, dependent_variable='predictions', regression_type=N
             'coefficient': coefs.values,
             'p_value': p_values.values
         })
+    elif regression_type in ['ridge', 'lasso']:
+        coefs = model.coef_
+        coefs = np.array(coefs).flatten()
+        # Calculate p-values
+        p_values = calculate_p_values(X, y, model)
+        p_values = np.array(p_values).flatten()
+
+        # Create a DataFrame for the coefficients and p-values
+        coef_df = pd.DataFrame({
+            'feature': X.columns,
+            'coefficient': coefs,
+            'p_value': p_values})
     else:
         coefs = model.coef_
         intercept = model.intercept_
@@ -1745,31 +1770,31 @@ def regression(df, csv_path, dependent_variable='predictions', regression_type=N
 
     return model, coef_df
 
-def set_regression_defaults(settings):
-    settings.setdefault('gene_weights_csv', '/nas_mnt/carruthers/Einar/mitoscreen/sequencing/combined_reads/EO1_combined/EO1_combined_combination_counts.csv')
-    settings.setdefault('dependent_variable','predictions')
-    settings.setdefault('transform',None)
-    settings.setdefault('agg_type','mean')
-    settings.setdefault('min_cell_count',25)
-    settings.setdefault('regression_type','ols')
-    settings.setdefault('remove_row_column_effect',False)
-    settings.setdefault('alpha',1)
-    settings.setdefault('fraction_threshold',0.1)
-    settings.setdefault('nc','c1')
-    settings.setdefault('pc','c2')
-    settings.setdefault('other','c3')
-    settings.setdefault('plate','plate1')
-
-    if settings['regression_type'] == 'quantile':
-        print(f"Using alpha as quantile for quantile regression, alpha: {settings['alpha']}")
-        settings['agg_type'] = None
-        print(f'agg_type set to None for quantile regression')
-    return settings
-
 def perform_regression(df, settings):
+
+    from spacr.plot import plot_plates
+    from .utils import merge_regression_res_with_metadata
+    from .settings import get_perform_regression_default_settings
+
+    reg_types = ['ols','gls','wls','rlm','glm','mixed','quantile','logit','probit','poisson','lasso','ridge']
+    if settings['regression_type'] not in reg_types:
+        print(f'Possible regression types: {reg_types}')
+        raise ValueError(f"Unsupported regression type {settings['regression_type']}")
+
+    if isinstance(df, str):
+        df = pd.read_csv(df)
+    elif isinstance(df, pd.DataFrame):
+        pass
+    else:
+        raise ValueError("Data must be a DataFrame or a path to a CSV file")
 
-    from spacr.plot import _plot_plates
 
+    if settings['dependent_variable'] not in df.columns:
+        print(f'Columns in DataFrame:')
+        for col in df.columns:
+            print(col)
+        raise ValueError(f"Dependent variable {settings['dependent_variable']} not found in the DataFrame")
+
     results_filename = os.path.splitext(os.path.basename(settings['gene_weights_csv']))[0] + '_results.csv'
     hits_filename = os.path.splitext(os.path.basename(settings['gene_weights_csv']))[0] + '_results_significant.csv'
 
@@ -1781,10 +1806,22 @@ def perform_regression(df, settings):
     results_path=os.path.join(os.path.dirname(settings['gene_weights_csv']), results_filename)
     hits_path=os.path.join(os.path.dirname(settings['gene_weights_csv']), hits_filename)
 
-    settings = set_regression_defaults(settings)
+    settings = get_perform_regression_default_settings(settings)
+
+    settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
+    settings_dir = os.path.dirname(settings['gene_weights_csv'])
+    settings_csv = os.path.join(settings_dir,f"{settings['regression_type']}_regression_settings.csv")
+    settings_df.to_csv(settings_csv, index=False)
+    display(settings_df)
 
     df = clean_controls(df,settings['pc'],settings['nc'],settings['other'])
+
+    if 'prediction_probability_class_1' in df.columns:
+        if not settings['class_1_threshold'] is None:
+            df['predictions'] = (df['prediction_probability_class_1'] >= settings['class_1_threshold']).astype(int)
+
     dependent_df, dependent_variable = process_scores(df, settings['dependent_variable'], settings['plate'], settings['min_cell_count'], settings['agg_type'], settings['transform'])
+
     display(dependent_df)
 
     independent_df = precess_reads(settings['gene_weights_csv'], settings['fraction_threshold'], settings['plate'])
@@ -1793,8 +1830,9 @@ def perform_regression(df, settings):
     merged_df = pd.merge(independent_df, dependent_df, on='prc')
 
     merged_df[['plate', 'row', 'column']] = merged_df['prc'].str.split('_', expand=True)
-
-    plate_heatmap = _plot_plates(df, variable=dependent_variable, grouping='mean', min_max='allq', cmap='viridis', min_count=settings['min_cell_count'])
+
+    if settings['transform'] is None:
+        _ = plot_plates(df, variable=dependent_variable, grouping='mean', min_max='allq', cmap='viridis', min_count=settings['min_cell_count'])
 
     model, coef_df = regression(merged_df, settings['gene_weights_csv'], dependent_variable, settings['regression_type'], settings['alpha'], settings['remove_row_column_effect'])
 
@@ -1813,6 +1851,17 @@ def perform_regression(df, settings):
     print(model.summary())
 
     significant.to_csv(hits_path, index=False)
+
+    me49 = '/home/carruthers/Documents/TGME49_Summary.csv'
+    gt1 = '/home/carruthers/Documents/TGGT1_Summary.csv'
+
+    _ = merge_regression_res_with_metadata(hits_path, me49, name='_me49_metadata')
+    _ = merge_regression_res_with_metadata(hits_path, gt1, name='_gt1_metadata')
+    _ = merge_regression_res_with_metadata(results_path, me49, name='_me49_metadata')
+    _ = merge_regression_res_with_metadata(results_path, gt1, name='_gt1_metadata')
+
     print('Significant Genes')
     display(significant)
     return coef_df
+
+
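
The calculate_p_values helper added above (and wired into the new ridge/lasso branch of regression) approximates coefficient p-values for scikit-learn estimators with an OLS-style t-test: residual variance estimates the noise, the pseudo-inverse of the intercept-augmented design matrix gives coefficient standard errors, and two-sided p-values come from the t distribution. Below is a self-contained sketch of the same recipe on toy data; the Ridge fit, feature names, and helper name are illustrative assumptions, not spacr code or data.

import numpy as np
import pandas as pd
from scipy import stats
from sklearn.linear_model import Ridge

def approx_coef_p_values(X, y, model):
    # Same recipe as calculate_p_values above: t-test using residual variance
    # and a pseudo-inverse of the intercept-augmented design matrix.
    residuals = y - model.predict(X)
    dof = X.shape[0] - X.shape[1] - 1
    residual_std_error = np.sqrt(np.sum(residuals ** 2) / dof)
    X_design = np.hstack((np.ones((X.shape[0], 1)), X))          # add intercept column
    coef_var_covar = residual_std_error ** 2 * np.linalg.pinv(X_design.T @ X_design)
    coef_standard_errors = np.sqrt(np.diag(coef_var_covar))
    t_stats = model.coef_ / coef_standard_errors[1:]             # skip the intercept term
    return np.array([2 * (1 - stats.t.cdf(np.abs(t), dof)) for t in t_stats])

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(200, 3)), columns=['g1', 'g2', 'g3'])
y = 2.0 * X['g1'] - 0.5 * X['g3'] + rng.normal(scale=0.5, size=200)
model = Ridge(alpha=1.0).fit(X, y)
print(pd.DataFrame({'feature': X.columns,
                    'coefficient': model.coef_,
                    'p_value': approx_coef_p_values(X, y, model)}))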