spacr 0.0.82__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +4 -0
- spacr/annotate_app.py +4 -0
- spacr/annotate_app_v2.py +511 -0
- spacr/core.py +254 -172
- spacr/deep_spacr.py +137 -50
- spacr/graph_learning.py +28 -8
- spacr/gui.py +5 -5
- spacr/gui_2.py +106 -36
- spacr/gui_classify_app.py +3 -3
- spacr/gui_mask_app.py +34 -11
- spacr/gui_measure_app.py +32 -17
- spacr/gui_utils.py +96 -29
- spacr/io.py +227 -144
- spacr/measure.py +2 -1
- spacr/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model +0 -0
- spacr/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +23 -0
- spacr/plot.py +102 -6
- spacr/sequencing.py +140 -91
- spacr/settings.py +477 -0
- spacr/timelapse.py +0 -3
- spacr/utils.py +312 -275
- {spacr-0.0.82.dist-info → spacr-0.1.1.dist-info}/METADATA +1 -1
- spacr-0.1.1.dist-info/RECORD +40 -0
- spacr-0.0.82.dist-info/RECORD +0 -36
- {spacr-0.0.82.dist-info → spacr-0.1.1.dist-info}/LICENSE +0 -0
- {spacr-0.0.82.dist-info → spacr-0.1.1.dist-info}/WHEEL +0 -0
- {spacr-0.0.82.dist-info → spacr-0.1.1.dist-info}/entry_points.txt +0 -0
- {spacr-0.0.82.dist-info → spacr-0.1.1.dist-info}/top_level.txt +0 -0
spacr/plot.py
CHANGED
@@ -1110,7 +1110,7 @@ def _imshow(img, labels, nrow=20, color='white', fontsize=12):
             idx = i * n_col + j
             if idx < n_images:
                 canvas[i * img_height:(i + 1) * img_height, j * img_width:(j + 1) * img_width] = np.transpose(img[idx], (1, 2, 0))
-    plt.figure(figsize=(50, 50))
+    fig = plt.figure(figsize=(50, 50))
     plt.imshow(canvas)
     plt.axis("off")
     for i, label in enumerate(labels):
@@ -1119,7 +1119,7 @@ def _imshow(img, labels, nrow=20, color='white', fontsize=12):
         x = col * img_width + 2
         y = row * img_height + 15
         plt.text(x, y, label, color=color, fontsize=fontsize, fontweight='bold')
-
+    return fig

 def _plot_histograms_and_stats(df):
     conditions = df['condition'].unique()
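Note: `_imshow` now returns its figure handle instead of leaving the grid on pyplot's implicit current figure. A minimal usage sketch (the inputs are invented; `_imshow` expects a batch of channels-first images, as the `np.transpose(img[idx], (1, 2, 0))` call above implies):

    import numpy as np
    from spacr.plot import _imshow

    imgs = np.random.rand(8, 3, 64, 64)      # 8 synthetic RGB images, CHW layout
    labels = [f'cell {i}' for i in range(8)]
    fig = _imshow(imgs, labels, nrow=4)       # returns the matplotlib Figure
    fig.savefig('grid.png', dpi=150)          # save explicitly, no plt.show() needed
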
@@ -1218,8 +1218,7 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_co
     df = df[df['count'] >= min_count]

     # Explicitly set observed=True to avoid FutureWarning
-    grouped = df.groupby(['row', 'col'], observed=True)
-
+    grouped = df.groupby(['row', 'col'], observed=True)  # Group by row and column

     if grouping == 'mean':
         plate = grouped[variable].mean().reset_index()
@@ -1245,7 +1244,7 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_co

     return plate_map, min_max

-def _plot_plates(df, variable, grouping, min_max, cmap, min_count=0):
+def plot_plates(df, variable, grouping, min_max, cmap, min_count=0, verbose=True):
     plates = df['prc'].str.split('_', expand=True)[0].unique()
     n_rows, n_cols = (len(plates) + 3) // 4, 4
     fig, ax = plt.subplots(n_rows, n_cols, figsize=(40, 5 * n_rows))
@@ -1260,7 +1259,8 @@ def _plot_plates(df, variable, grouping, min_max, cmap, min_count=0):
         fig.delaxes(ax[i])

     plt.subplots_adjust(wspace=0.1, hspace=0.4)
-
+    if verbose:
+        plt.show()
     return fig

 def print_mask_and_flows(stack, mask, flows, overlay=False):
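The rename from `_plot_plates` to `plot_plates` makes the helper public, and the new `verbose` flag lets pipeline code skip `plt.show()` while still receiving the figure. A usage sketch (arguments mirror the call in `perform_regression` below; the DataFrame is assumed to carry the `prc` column the body splits on):

    fig = plot_plates(df, variable='predictions', grouping='mean',
                      min_max='allq', cmap='viridis', min_count=25,
                      verbose=False)   # headless: no plt.show()
    fig.savefig('plate_heatmaps.pdf')
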
@@ -1470,3 +1470,99 @@ def plot_object_outlines(src, objects=['nucleus','cell','pathogen'], channels=[0
                          overlay=True,
                          max_nr=10,
                          randomize=True)
+
+def volcano_plot(coef_df, filename='volcano_plot.pdf'):
+    # Create the volcano plot
+    plt.figure(figsize=(10, 6))
+    sns.scatterplot(
+        data=coef_df,
+        x='coefficient',
+        y='-log10(p_value)',
+        hue='highlight',
+        palette={True: 'red', False: 'blue'}
+    )
+    plt.title('Volcano Plot of Coefficients')
+    plt.xlabel('Coefficient')
+    plt.ylabel('-log10(p-value)')
+    plt.axhline(y=-np.log10(0.05), color='red', linestyle='--')
+    plt.legend().remove()
+    plt.savefig(filename, format='pdf')
+    print(f'Saved Volcano plot: {filename}')
+    plt.show()
+
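`volcano_plot` assumes `coef_df` already contains `coefficient`, `-log10(p_value)` and boolean `highlight` columns. A minimal sketch of building such a frame (the column derivations are assumptions based on the names used above):

    import numpy as np
    import pandas as pd

    coef_df = pd.DataFrame({'feature': ['geneA', 'geneB'],
                            'coefficient': [0.8, -0.2],
                            'p_value': [0.001, 0.4]})
    coef_df['-log10(p_value)'] = -np.log10(coef_df['p_value'])
    coef_df['highlight'] = coef_df['p_value'] < 0.05   # rows drawn in red
    volcano_plot(coef_df, filename='volcano_plot.pdf')
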
+def plot_histogram(df, dependent_variable):
+    # Plot histogram of the dependent variable
+    plt.figure(figsize=(10, 6))
+    sns.histplot(df[dependent_variable], kde=True)
+    plt.title(f'Histogram of {dependent_variable}')
+    plt.xlabel(dependent_variable)
+    plt.ylabel('Frequency')
+    plt.show()
+
+def plot_lorenz_curves(csv_files, remove_keys=['TGGT1_220950_1', 'TGGT1_233460_4']):
+
+    def lorenz_curve(data):
+        """Calculate Lorenz curve."""
+        sorted_data = np.sort(data)
+        cumulative_data = np.cumsum(sorted_data)
+        lorenz_curve = cumulative_data / cumulative_data[-1]
+        lorenz_curve = np.insert(lorenz_curve, 0, 0)
+        return lorenz_curve
+
+    combined_data = []
+
+    plt.figure(figsize=(10, 6))
+
+    for idx, csv_file in enumerate(csv_files):
+        if idx == 1:
+            save_fldr = os.path.dirname(csv_file)
+            save_path = os.path.join(save_fldr, 'lorenz_curve.pdf')
+
+        df = pd.read_csv(csv_file)
+        for remove in remove_keys:
+            df = df[df['key'] != remove]
+
+        values = df['value'].values
+        combined_data.extend(values)
+
+        lorenz = lorenz_curve(values)
+        name = os.path.basename(csv_file)[:3]
+        plt.plot(np.linspace(0, 1, len(lorenz)), lorenz, label=name)
+
+    # Plot combined Lorenz curve
+    combined_lorenz = lorenz_curve(np.array(combined_data))
+    plt.plot(np.linspace(0, 1, len(combined_lorenz)), combined_lorenz, label="Combined Lorenz Curve", linestyle='--', color='black')
+
+    plt.title('Lorenz Curves')
+    plt.xlabel('Cumulative Share of Individuals')
+    plt.ylabel('Cumulative Share of Value')
+    plt.legend()
+    plt.grid(False)
+    plt.savefig(save_path)
+    plt.show()
+
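The Lorenz helper sorts the per-key values and plots their cumulative share, so an even library traces the diagonal and a skewed one bows below it; note that `save_path` is only assigned on the second file (`idx == 1`), so the function as written expects at least two CSVs. The same cumulative-sum construction yields a Gini coefficient as a single-number skew summary; a companion sketch (not part of spacr; trapezoid rule written out to avoid NumPy version differences):

    import numpy as np

    def gini(data):
        # Gini = 1 - 2 * area under the Lorenz curve
        sorted_data = np.sort(np.asarray(data, dtype=float))
        lorenz = np.insert(np.cumsum(sorted_data) / sorted_data.sum(), 0, 0)
        dx = 1.0 / (len(lorenz) - 1)
        area = np.sum((lorenz[1:] + lorenz[:-1]) / 2.0) * dx
        return 1 - 2 * area

    print(gini([1, 1, 1, 1]))   # 0.0  -> perfectly even representation
    print(gini([0, 0, 0, 10]))  # 0.75 -> highly skewed library
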
+def plot_permutation(permutation_df):
+    num_features = len(permutation_df)
+    fig_height = max(8, num_features * 0.3)  # Set a minimum height of 8 and adjust height based on number of features
+    fig_width = 10  # Width can be fixed or adjusted similarly
+    font_size = max(10, 12 - num_features * 0.2)  # Adjust font size dynamically
+
+    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
+    ax.barh(permutation_df['feature'], permutation_df['importance_mean'], xerr=permutation_df['importance_std'], color="teal", align="center", alpha=0.6)
+    ax.set_xlabel('Permutation Importance', fontsize=font_size)
+    ax.tick_params(axis='both', which='major', labelsize=font_size)
+    plt.tight_layout()
+    return fig
+
+def plot_feature_importance(feature_importance_df):
+    num_features = len(feature_importance_df)
+    fig_height = max(8, num_features * 0.3)  # Set a minimum height of 8 and adjust height based on number of features
+    fig_width = 10  # Width can be fixed or adjusted similarly
+    font_size = max(10, 12 - num_features * 0.2)  # Adjust font size dynamically
+
+    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
+    ax.barh(feature_importance_df['feature'], feature_importance_df['importance'], color="blue", align="center", alpha=0.6)
+    ax.set_xlabel('Feature Importance', fontsize=font_size)
+    ax.tick_params(axis='both', which='major', labelsize=font_size)
+    plt.tight_layout()
+    return fig
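
Both importance plots return their `Figure` for the caller to save; a sketch of driving them from scikit-learn outputs (estimator and data are placeholders):

    import pandas as pd
    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.inspection import permutation_importance

    X, y = make_classification(n_samples=200, n_features=5, random_state=0)
    X = pd.DataFrame(X, columns=[f'f{i}' for i in range(5)])
    clf = RandomForestClassifier(random_state=0).fit(X, y)

    perm = permutation_importance(clf, X, y, n_repeats=5, random_state=0)
    perm_df = pd.DataFrame({'feature': X.columns,
                            'importance_mean': perm.importances_mean,
                            'importance_std': perm.importances_std})
    plot_permutation(perm_df).savefig('permutation_importance.pdf')

    fi_df = pd.DataFrame({'feature': X.columns,
                          'importance': clf.feature_importances_})
    plot_feature_importance(fi_df).savefig('feature_importance.pdf')
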
spacr/sequencing.py
CHANGED
@@ -10,6 +10,7 @@ import statsmodels.api as sm
 from statsmodels.regression.mixed_linear_model import MixedLM
 from statsmodels.stats.outliers_influence import variance_inflation_factor
 from scipy.stats import gmean
+from scipy import stats
 from difflib import SequenceMatcher
 from collections import Counter
 from IPython.display import display
@@ -305,13 +306,10 @@ def analyze_reads(settings):
     qc_df = pd.DataFrame([qc])
     qc_df.to_csv(qc_file_path, index=False)

-
-
-    settings
-
-    settings.setdefault('chunk_size', 1000000)
-    settings.setdefault('test', False)
-
+    from .utils import get_analyze_reads_default_settings
+
+    settings = get_analyze_reads_default_settings(settings)
+
     samples_dict = parse_gz_files(settings['src'])
     combine_reads(samples_dict, settings['src'], settings['chunk_size'], settings['barecode_length_1'], settings['barecode_length_2'], settings['upstream'], settings['downstream'])

@@ -477,17 +475,10 @@ def map_barcodes(h5_file_path, settings={}):
         #print(f"Max count for {nc} in other columns: {max_count_c3}")

         return filtered_df
+
+    from .settings import get_map_barcodes_default_settings

-    settings
-    settings.setdefault('barcodes', '/home/carruthers/Documents/SCREEN_BARCODES.csv')
-    settings.setdefault('plate_dict', {'EO1': 'plate1', 'EO2': 'plate2', 'EO3': 'plate3', 'EO4': 'plate4', 'EO5': 'plate5', 'EO6': 'plate6', 'EO7': 'plate7', 'EO8': 'plate8'})
-    settings.setdefault('test', False)
-    settings.setdefault('verbose', True)
-
-    settings.setdefault('pc', 'TGGT1_220950_1')
-    settings.setdefault('pc_loc', 'c2')
-    settings.setdefault('nc', 'TGGT1_233460_4')
-    settings.setdefault('nc_loc', 'c1')
+    settings = get_map_barcodes_default_settings(settings)

     fldr = os.path.splitext(h5_file_path)[0]
     file_name = os.path.basename(fldr)
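Both hunks above swap inline `setdefault` blocks for centralized helpers (`get_analyze_reads_default_settings` in spacr.utils, `get_map_barcodes_default_settings` in spacr.settings), which is what the new spacr/settings.py module exists for. A sketch of the pattern, using the defaults visible in the removed lines (the helper body here is an assumption; the real one lives in spacr/settings.py):

    def get_map_barcodes_default_settings(settings):
        # Fill in any keys the caller did not supply (values from the removed block above).
        settings.setdefault('barcodes', '/home/carruthers/Documents/SCREEN_BARCODES.csv')
        settings.setdefault('plate_dict', {'EO1': 'plate1', 'EO2': 'plate2', 'EO3': 'plate3', 'EO4': 'plate4',
                                           'EO5': 'plate5', 'EO6': 'plate6', 'EO7': 'plate7', 'EO8': 'plate8'})
        settings.setdefault('test', False)
        settings.setdefault('verbose', True)
        settings.setdefault('pc', 'TGGT1_220950_1')
        settings.setdefault('pc_loc', 'c2')
        settings.setdefault('nc', 'TGGT1_233460_4')
        settings.setdefault('nc_loc', 'c1')
        return settings
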
@@ -1393,16 +1384,6 @@ def generate_fraction_map(df, gene_column, min_=10, plates=['p1','p2','p3','p4']
     independent_variables.index.name = 'prc'
     independent_variables = independent_variables.loc[:, (independent_variables.sum() != 0)]
     return independent_variables
-
-
-def plot_histogram(df, dependent_variable):
-    # Plot histogram of the dependent variable
-    plt.figure(figsize=(10, 6))
-    sns.histplot(df[dependent_variable], kde=True)
-    plt.title(f'Histogram of {dependent_variable}')
-    plt.xlabel(dependent_variable)
-    plt.ylabel('Frequency')
-    plt.show()

 def precess_reads(csv_path, fraction_threshold, plate):
     # Read the CSV file into a DataFrame
@@ -1475,32 +1456,42 @@ def check_normality(data, variable_name, verbose=False):
         print(f"The data for {variable_name} is not normally distributed.")
         return False

-def process_scores(df, dependent_variable, plate, min_cell_count=25, agg_type='mean', transform=None):
+def process_scores(df, dependent_variable, plate, min_cell_count=25, agg_type='mean', transform=None, regression_type='ols'):

     if plate is not None:
         df['plate'] = plate

+    if 'col' not in df.columns:
+        df['col'] = df['column']
+
     df['prc'] = df['plate'] + '_' + df['row'] + '_' + df['col']
     df = df[['prc', dependent_variable]]

     # Group by prc and calculate the mean and count of the dependent_variable
     grouped = df.groupby('prc')[dependent_variable]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+    if regression_type != 'poisson':
+
+        print(f'Using agg_type: {agg_type}')
+
+        if agg_type == 'median':
+            dependent_df = grouped.median().reset_index()
+        elif agg_type == 'mean':
+            dependent_df = grouped.mean().reset_index()
+        elif agg_type == 'quantile':
+            dependent_df = grouped.quantile(0.75).reset_index()
+        elif agg_type == None:
+            dependent_df = df.reset_index()
+            if 'prcfo' in dependent_df.columns:
+                dependent_df = dependent_df.drop(columns=['prcfo'])
+        else:
+            raise ValueError(f"Unsupported aggregation type {agg_type}")
+
+    if regression_type == 'poisson':
+        agg_type = 'count'
+        print(f'Using agg_type: {agg_type} for poisson regression')
+        dependent_df = grouped.sum().reset_index()
+
     # Calculate cell_count for all cases
     cell_count = grouped.size().reset_index(name='cell_count')

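The reworked `process_scores` keys its per-well aggregation off `regression_type`: everything except poisson aggregates scores per `prc` well with `agg_type` (mean, median, or the 0.75 quantile), while poisson sums them so the dependent variable behaves like a count. A toy illustration (data invented):

    import pandas as pd

    df = pd.DataFrame({'prc': ['p1_r1_c1'] * 3 + ['p1_r1_c2'] * 2,
                       'predictions': [1, 0, 1, 0, 1]})
    grouped = df.groupby('prc')['predictions']
    print(grouped.mean())   # 0.667 / 0.5  -> ols-style per-well scores
    print(grouped.sum())    # 2     / 1    -> poisson-style per-well counts
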
@@ -1516,7 +1507,7 @@ def process_scores(df, dependent_variable, plate, min_cell_count=25, agg_type='m
     if not transform is None:
         transformer = apply_transformation(dependent_df[dependent_variable], transform=transform)
         transformed_var = f'{transform}_{dependent_variable}'
-
+        dependent_df[transformed_var] = transformer.fit_transform(dependent_df[[dependent_variable]])
         dependent_variable = transformed_var
         is_normal = check_normality(dependent_df[transformed_var], transformed_var)

@@ -1557,6 +1548,7 @@ def regression_model(X, y, regression_type='ols', groups=None, alpha=1.0, remove
         model = sm.GLS(y, X).fit()

     elif regression_type == 'wls':
+        weights = 1 / np.sqrt(X.iloc[:, 1])
         model = sm.WLS(y, X, weights=weights).fit()

     elif regression_type == 'rlm':
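The added line defines the previously unbound `weights` name: observations are weighted by the inverse square root of the second design column (index 1, the first regressor after the constant), i.e. the fit assumes noise grows with that regressor. A standalone sketch of the same call (synthetic data; the weighting scheme simply mirrors the line above):

    import numpy as np
    import pandas as pd
    import statsmodels.api as sm

    rng = np.random.default_rng(0)
    x = rng.uniform(1, 10, 100)
    y = 2 * x + rng.normal(0, np.sqrt(x))      # heteroscedastic noise
    X = sm.add_constant(pd.DataFrame({'x': x}))
    weights = 1 / np.sqrt(X.iloc[:, 1])        # same scheme as the diff above
    print(sm.WLS(y, X, weights=weights).fit().params)
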
@@ -1589,7 +1581,7 @@ def regression_model(X, y, regression_type='ols', groups=None, alpha=1.0, remove
         model = sm.Probit(y, X).fit()

     elif regression_type == 'poisson':
-
+        model = sm.Poisson(y, X).fit()

     elif regression_type == 'lasso':
         model = Lasso(alpha=alpha).fit(X, y)
@@ -1611,25 +1603,6 @@ def regression_model(X, y, regression_type='ols', groups=None, alpha=1.0, remove

     return model

-def volcano_plot(coef_df, filename='volcano_plot.pdf'):
-    # Create the volcano plot
-    plt.figure(figsize=(10, 6))
-    sns.scatterplot(
-        data=coef_df,
-        x='coefficient',
-        y='-log10(p_value)',
-        hue='highlight',
-        palette={True: 'red', False: 'blue'}
-    )
-    plt.title('Volcano Plot of Coefficients')
-    plt.xlabel('Coefficient')
-    plt.ylabel('-log10(p-value)')
-    plt.axhline(y=-np.log10(0.05), color='red', linestyle='--')
-    plt.legend().remove()
-    plt.savefig(filename, format='pdf')
-    print(f'Saved Volcano plot: {filename}')
-    plt.show()
-
 def clean_controls(df,pc,nc,other):
     if 'col' in df.columns:
         df['column'] = df['col']
@@ -1642,20 +1615,60 @@ def clean_controls(df,pc,nc,other):
     print(f'Removed data from {nc, pc, other}')
     return df

+# Remove outliers by capping values at 1st and 99th percentiles for numerical columns only
+def remove_outliers(df, low=0.01, high=0.99):
+    numerical_cols = df.select_dtypes(include=[np.number]).columns
+    quantiles = df[numerical_cols].quantile([low, high])
+    for col in numerical_cols:
+        df[col] = np.clip(df[col], quantiles.loc[low, col], quantiles.loc[high, col])
+    return df
+
+def calculate_p_values(X, y, model):
+    # Predict y values
+    y_pred = model.predict(X)
+
+    # Calculate residuals
+    residuals = y - y_pred
+
+    # Calculate the standard error of the residuals
+    dof = X.shape[0] - X.shape[1] - 1
+    residual_std_error = np.sqrt(np.sum(residuals ** 2) / dof)
+
+    # Calculate the standard error of the coefficients
+    X_design = np.hstack((np.ones((X.shape[0], 1)), X))  # Add intercept
+
+    # Use pseudoinverse instead of inverse to handle singular matrices
+    coef_var_covar = residual_std_error ** 2 * np.linalg.pinv(X_design.T @ X_design)
+    coef_standard_errors = np.sqrt(np.diag(coef_var_covar))
+
+    # Calculate t-statistics
+    t_stats = model.coef_ / coef_standard_errors[1:]  # Skip intercept error
+
+    # Calculate p-values
+    p_values = [2 * (1 - stats.t.cdf(np.abs(t), dof)) for t in t_stats]
+
+    return np.array(p_values)  # Ensure p_values is a 1-dimensional array
+
 def regression(df, csv_path, dependent_variable='predictions', regression_type=None, alpha=1.0, remove_row_column_effect=False):

+    from .plot import volcano_plot, plot_histogram
+
     volcano_filename = os.path.splitext(os.path.basename(csv_path))[0] + '_volcano_plot.pdf'
     volcano_filename = regression_type+'_'+volcano_filename
     if regression_type == 'quantile':
         volcano_filename = str(alpha)+'_'+volcano_filename
     volcano_path=os.path.join(os.path.dirname(csv_path), volcano_filename)

+    is_normal = check_normality(df[dependent_variable], dependent_variable)
+
     if regression_type is None:
         if is_normal:
             regression_type = 'ols'
         else:
             regression_type = 'glm'

+    #df = remove_outliers(df)
+
     if remove_row_column_effect:

         ## 1. Fit the initial model with row and column to estimate their effects
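`calculate_p_values` backfills Wald-style p-values for the scikit-learn estimators (`Lasso`, `Ridge`) that expose `coef_` but no `.pvalues`: residual variance times the pseudoinverse of the design's normal matrix gives coefficient standard errors, and a two-sided t-test does the rest. A sanity-check sketch against statsmodels OLS (synthetic data; close agreement is only expected for an unpenalized fit):

    import numpy as np
    import pandas as pd
    import statsmodels.api as sm
    from sklearn.linear_model import LinearRegression

    rng = np.random.default_rng(1)
    X = pd.DataFrame(rng.normal(size=(200, 3)), columns=['a', 'b', 'c'])
    y = 1.5 * X['a'] + rng.normal(size=200)

    skl = LinearRegression().fit(X, y)
    print(calculate_p_values(X, y, skl))                     # helper defined above
    print(sm.OLS(y, sm.add_constant(X)).fit().pvalues[1:])   # should be close
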
@@ -1714,7 +1727,7 @@ def regression(df, csv_path, dependent_variable='predictions', regression_type=N
     model = regression_model(X, y, regression_type=regression_type, groups=groups, alpha=alpha, remove_row_column_effect=remove_row_column_effect)

     # Get the model coefficients and p-values
-    if regression_type in ['ols','gls','wls','rlm','glm','mixed','quantile','logit','probit','poisson'
+    if regression_type in ['ols','gls','wls','rlm','glm','mixed','quantile','logit','probit','poisson']:
         coefs = model.params
         p_values = model.pvalues

@@ -1723,6 +1736,18 @@ def regression(df, csv_path, dependent_variable='predictions', regression_type=N
             'coefficient': coefs.values,
             'p_value': p_values.values
         })
+    elif regression_type in ['ridge', 'lasso']:
+        coefs = model.coef_
+        coefs = np.array(coefs).flatten()
+        # Calculate p-values
+        p_values = calculate_p_values(X, y, model)
+        p_values = np.array(p_values).flatten()
+
+        # Create a DataFrame for the coefficients and p-values
+        coef_df = pd.DataFrame({
+            'feature': X.columns,
+            'coefficient': coefs,
+            'p_value': p_values})
     else:
         coefs = model.coef_
         intercept = model.intercept_
@@ -1745,31 +1770,31 @@ def regression(df, csv_path, dependent_variable='predictions', regression_type=N

     return model, coef_df

-def set_regression_defaults(settings):
-    settings.setdefault('gene_weights_csv', '/nas_mnt/carruthers/Einar/mitoscreen/sequencing/combined_reads/EO1_combined/EO1_combined_combination_counts.csv')
-    settings.setdefault('dependent_variable','predictions')
-    settings.setdefault('transform',None)
-    settings.setdefault('agg_type','mean')
-    settings.setdefault('min_cell_count',25)
-    settings.setdefault('regression_type','ols')
-    settings.setdefault('remove_row_column_effect',False)
-    settings.setdefault('alpha',1)
-    settings.setdefault('fraction_threshold',0.1)
-    settings.setdefault('nc','c1')
-    settings.setdefault('pc','c2')
-    settings.setdefault('other','c3')
-    settings.setdefault('plate','plate1')
-
-    if settings['regression_type'] == 'quantile':
-        print(f"Using alpha as quantile for quantile regression, alpha: {settings['alpha']}")
-        settings['agg_type'] = None
-        print(f'agg_type set to None for quantile regression')
-    return settings
-
 def perform_regression(df, settings):
+
+    from spacr.plot import plot_plates
+    from .utils import merge_regression_res_with_metadata
+    from .settings import get_perform_regression_default_settings
+
+    reg_types = ['ols','gls','wls','rlm','glm','mixed','quantile','logit','probit','poisson','lasso','ridge']
+    if settings['regression_type'] not in reg_types:
+        print(f'Possible regression types: {reg_types}')
+        raise ValueError(f"Unsupported regression type {settings['regression_type']}")
+
+    if isinstance(df, str):
+        df = pd.read_csv(df)
+    elif isinstance(df, pd.DataFrame):
+        pass
+    else:
+        raise ValueError("Data must be a DataFrame or a path to a CSV file")

-    from spacr.plot import _plot_plates

+    if settings['dependent_variable'] not in df.columns:
+        print(f'Columns in DataFrame:')
+        for col in df.columns:
+            print(col)
+        raise ValueError(f"Dependent variable {settings['dependent_variable']} not found in the DataFrame")
+
     results_filename = os.path.splitext(os.path.basename(settings['gene_weights_csv']))[0] + '_results.csv'
     hits_filename = os.path.splitext(os.path.basename(settings['gene_weights_csv']))[0] + '_results_significant.csv'

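`perform_regression` now validates the regression type and the dependent variable up front and accepts either a DataFrame or a CSV path. A minimal settings sketch (keys collected from what the function reads and from the removed `set_regression_defaults`; paths are placeholders, and missing keys are filled by `get_perform_regression_default_settings`):

    settings = {
        'gene_weights_csv': '/path/to/EO1_combined_combination_counts.csv',
        'dependent_variable': 'predictions',
        'regression_type': 'ols',          # must be one of reg_types above
        'transform': None,
        'agg_type': 'mean',
        'min_cell_count': 25,
        'alpha': 1,
        'fraction_threshold': 0.1,
        'nc': 'c1', 'pc': 'c2', 'other': 'c3',
        'plate': 'plate1',
        'remove_row_column_effect': False,
        'class_1_threshold': None,         # optionally binarize class-1 probabilities
    }
    coef_df = perform_regression('/path/to/scores.csv', settings)
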
@@ -1781,10 +1806,22 @@ def perform_regression(df, settings):
     results_path=os.path.join(os.path.dirname(settings['gene_weights_csv']), results_filename)
     hits_path=os.path.join(os.path.dirname(settings['gene_weights_csv']), hits_filename)

-    settings = set_regression_defaults(settings)
+    settings = get_perform_regression_default_settings(settings)
+
+    settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
+    settings_dir = os.path.dirname(settings['gene_weights_csv'])
+    settings_csv = os.path.join(settings_dir,f"{settings['regression_type']}_regression_settings.csv")
+    settings_df.to_csv(settings_csv, index=False)
+    display(settings_df)

     df = clean_controls(df,settings['pc'],settings['nc'],settings['other'])
+
+    if 'prediction_probability_class_1' in df.columns:
+        if not settings['class_1_threshold'] is None:
+            df['predictions'] = (df['prediction_probability_class_1'] >= settings['class_1_threshold']).astype(int)
+
     dependent_df, dependent_variable = process_scores(df, settings['dependent_variable'], settings['plate'], settings['min_cell_count'], settings['agg_type'], settings['transform'])
+
     display(dependent_df)

     independent_df = precess_reads(settings['gene_weights_csv'], settings['fraction_threshold'], settings['plate'])
@@ -1793,8 +1830,9 @@ def perform_regression(df, settings):
     merged_df = pd.merge(independent_df, dependent_df, on='prc')

     merged_df[['plate', 'row', 'column']] = merged_df['prc'].str.split('_', expand=True)
-
-
+
+    if settings['transform'] is None:
+        _ = plot_plates(df, variable=dependent_variable, grouping='mean', min_max='allq', cmap='viridis', min_count=settings['min_cell_count'])

     model, coef_df = regression(merged_df, settings['gene_weights_csv'], dependent_variable, settings['regression_type'], settings['alpha'], settings['remove_row_column_effect'])

@@ -1813,6 +1851,17 @@ def perform_regression(df, settings):
     print(model.summary())

     significant.to_csv(hits_path, index=False)
+
+    me49 = '/home/carruthers/Documents/TGME49_Summary.csv'
+    gt1 = '/home/carruthers/Documents/TGGT1_Summary.csv'
+
+    _ = merge_regression_res_with_metadata(hits_path, me49, name='_me49_metadata')
+    _ = merge_regression_res_with_metadata(hits_path, gt1, name='_gt1_metadata')
+    _ = merge_regression_res_with_metadata(results_path, me49, name='_me49_metadata')
+    _ = merge_regression_res_with_metadata(results_path, gt1, name='_gt1_metadata')
+
     print('Significant Genes')
     display(significant)
     return coef_df
+
+