pelican-nlp 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. pelican_nlp/Nils_backup/__init__.py +0 -0
  2. pelican_nlp/Nils_backup/extract_acoustic_features.py +274 -0
  3. pelican_nlp/Nils_backup/fluency/__init__.py +0 -0
  4. pelican_nlp/Nils_backup/fluency/aggregate_fluency_results.py +186 -0
  5. pelican_nlp/Nils_backup/fluency/behavioral_data.py +42 -0
  6. pelican_nlp/Nils_backup/fluency/check_duplicates.py +169 -0
  7. pelican_nlp/Nils_backup/fluency/coherence.py +653 -0
  8. pelican_nlp/Nils_backup/fluency/config.py +231 -0
  9. pelican_nlp/Nils_backup/fluency/main.py +182 -0
  10. pelican_nlp/Nils_backup/fluency/optimality_without_tsa.py +466 -0
  11. pelican_nlp/Nils_backup/fluency/plot_fluency.py +573 -0
  12. pelican_nlp/Nils_backup/fluency/plotting_utils.py +170 -0
  13. pelican_nlp/Nils_backup/fluency/questionnaires_data.py +43 -0
  14. pelican_nlp/Nils_backup/fluency/stats_fluency.py +930 -0
  15. pelican_nlp/Nils_backup/fluency/utils.py +41 -0
  16. pelican_nlp/Nils_backup/speaker_diarization_Nils.py +328 -0
  17. pelican_nlp/Nils_backup/transcription/__init__.py +0 -0
  18. pelican_nlp/Nils_backup/transcription/annotation_tool.py +1001 -0
  19. pelican_nlp/Nils_backup/transcription/annotation_tool_boundaries.py +1122 -0
  20. pelican_nlp/Nils_backup/transcription/annotation_tool_sandbox.py +985 -0
  21. pelican_nlp/Nils_backup/transcription/output/holmes_control_nova_all_outputs.json +7948 -0
  22. pelican_nlp/Nils_backup/transcription/test.json +1 -0
  23. pelican_nlp/Nils_backup/transcription/transcribe_audio.py +314 -0
  24. pelican_nlp/Nils_backup/transcription/transcribe_audio_chunked.py +695 -0
  25. pelican_nlp/Nils_backup/transcription/transcription.py +801 -0
  26. pelican_nlp/Nils_backup/transcription/transcription_gui.py +955 -0
  27. pelican_nlp/Nils_backup/transcription/word_boundaries.py +190 -0
  28. pelican_nlp/Silvia_files/Opensmile/opensmile_feature_extraction.py +66 -0
  29. pelican_nlp/Silvia_files/prosogram/prosogram.py +104 -0
  30. pelican_nlp/__init__.py +1 -1
  31. pelican_nlp/_version.py +1 -0
  32. pelican_nlp/configuration_files/config_audio.yml +150 -0
  33. pelican_nlp/configuration_files/config_discourse.yml +104 -0
  34. pelican_nlp/configuration_files/config_fluency.yml +108 -0
  35. pelican_nlp/configuration_files/config_general.yml +131 -0
  36. pelican_nlp/configuration_files/config_morteza.yml +103 -0
  37. pelican_nlp/praat/__init__.py +29 -0
  38. {pelican_nlp-0.1.1.dist-info → pelican_nlp-0.1.3.dist-info}/METADATA +15 -14
  39. pelican_nlp-0.1.3.dist-info/RECORD +75 -0
  40. pelican_nlp-0.1.1.dist-info/RECORD +0 -39
  41. {pelican_nlp-0.1.1.dist-info → pelican_nlp-0.1.3.dist-info}/WHEEL +0 -0
  42. {pelican_nlp-0.1.1.dist-info → pelican_nlp-0.1.3.dist-info}/licenses/LICENSE +0 -0
  43. {pelican_nlp-0.1.1.dist-info → pelican_nlp-0.1.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,573 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import os
4
+ import sys
5
+ import pandas as pd
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ import seaborn as sns
9
+ from scipy.stats import ttest_ind, f_oneway
10
+ import statsmodels.api as sm
11
+ from statsmodels.formula.api import ols
12
+ from statsmodels.stats.multicomp import pairwise_tukeyhsd
13
+ from matplotlib import cm
14
+ from sklearn.preprocessing import StandardScaler
15
+ from sklearn.linear_model import LinearRegression
16
+ from sklearn.metrics import r2_score
17
+ import statsmodels.api as sm
18
+ import statsmodels.formula.api as smf
19
+ from scipy.stats import pearsonr
20
+ from scipy import stats
21
+ from sklearn.preprocessing import StandardScaler, OneHotEncoder
22
+ from sklearn.compose import ColumnTransformer
23
+ from plotting_utils import (
24
+ COLORS as colors,
25
+ GROUP_DICT as group_dict,
26
+ NAMES as names,
27
+ METRICS as metrics,
28
+ COG_VAR as cog_var,
29
+ OUTCOME_COLORS as outcome_colors,
30
+ SORTED_OUTCOMES as sorted_outcomes,
31
+ format_p_value,
32
+ set_size
33
+ )
34
+
35
# Remember the PATH as it was before this module touched it.
original_path = os.environ.get("PATH", "")

# Directory where MacTeX installs its executables (macOS default location).
# NOTE(review): presumably prepended so matplotlib's LaTeX text rendering can
# find `latex`; confirm against the matplotlib rc settings used by the project.
latex_path = "/Library/TeX/texbin"

# Prepend the MacTeX directory to PATH using the platform separator.
# - An empty original PATH must not yield "<latex_path>:" -- a trailing empty
#   entry means "current directory" on POSIX.
# - If the directory is already first (e.g. after a module reload), leave PATH
#   alone instead of stacking duplicate entries.
if not original_path:
    os.environ["PATH"] = latex_path
elif original_path.split(os.pathsep)[0] != latex_path:
    os.environ["PATH"] = latex_path + os.pathsep + original_path
+
44
def compare_groups(df, metric, group_col, task, task_col, path, include_mean=False):
    """
    Plot the distribution of ``metric`` per sub-task, split by group.

    For the selected task, individual observations are drawn as a jittered
    strip plot and group means with 95% confidence intervals as a point plot,
    with one x-position per ``sub_task``. A per-``study_id`` mean row
    (labelled "Mean") is also assembled; it is only plotted when
    ``include_mean`` is True.

    Despite the function name, no statistical test is run here: both returned
    lists are always empty. They are returned so callers that unpack two
    values keep working.

    Parameters:
    - df: DataFrame containing the data. Must have "study_id" and "sub_task"
      columns in addition to ``group_col``, ``metric`` and ``task_col``.
    - metric: Column name of the metric to plot (also a key of ``names``).
    - group_col: Column indicating the group labels (mapped through
      ``group_dict`` for display).
    - task: Task to filter by; a falsy value keeps all rows.
    - task_col: Column name indicating the task in the DataFrame.
    - path: Prefix for the output file. The figure is saved to
      ``path + "group_{metric}_{task}_{group_col}.tiff"``; no separator is
      inserted, so ``path`` should end with one.
    - include_mean: If True, also plot the per-subject means as an extra
      "Mean" category.

    Returns:
    - (sig_list, sig_list_pairwise): two empty lists (placeholders).
    """
    sig_list = []
    sig_list_pairwise = []

    df_task = df[df[task_col] == task] if task else df

    # Raw observations plus one aggregated row per subject, labelled "mean".
    data = df_task[["study_id", "sub_task", group_col, metric]]
    aggregations = {metric: "mean", group_col: "first"}
    mean = df_task.groupby("study_id").agg(aggregations).reset_index()
    mean["sub_task"] = "mean"
    data = pd.concat([data, mean]).reset_index(drop=True).dropna()
    data["sub_task"] = data["sub_task"].apply(lambda x: x.capitalize())
    data[group_col] = data[group_col].replace(group_dict)
    name = f"group_{metric}_{task}_{group_col}"

    # The labels were capitalised above, so the aggregated rows are now
    # "Mean". The old filter compared against lowercase "mean", never
    # matched, and silently drew the means as an extra category.
    mean_label = "mean".capitalize()
    plot_data = data if include_mean else data[data["sub_task"] != mean_label]

    plt.figure(figsize=set_size(fraction=1), dpi=300)
    sns.stripplot(
        data=plot_data,
        x="sub_task",
        y=metric,
        hue=group_col,
        jitter=True,
        alpha=0.4,
        dodge=True,
        legend=None,
        size=4,
        palette=colors,
    )
    sns.pointplot(
        data=plot_data,
        x="sub_task",
        y=metric,
        hue=group_col,
        capsize=0.2,
        markers="s",
        errorbar='ci',
        dodge=0.4 if group_col == "class" else 0.55,
        palette=colors,
    )

    # Raw string: "\%" is meant literally (escaped percent for LaTeX text).
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.1),
               title=r"Group means with 95\% CI", ncol=4, frameon=True)

    plt.ylabel(names[metric])
    plt.xlabel("")
    plt.savefig(path + name + ".tiff")
    plt.show()

    return sig_list, sig_list_pairwise
113
+
114
+
115
def compare_groups_single_plot_regression(df, metrics, group_col, path=None, scale_data=True, n_comp=1):
    """
    Compare groups across several metrics in one figure and test each metric
    with an OLS model on the categorical group column.

    For every metric, ``metric ~ C(group_col)`` is fitted. The overall F-test
    p-value and R² are annotated above the metric. If the F-test passes a
    Bonferroni-style threshold, Tukey HSD post-hoc comparisons are run.

    Parameters:
    - df: DataFrame containing the data (not modified).
    - metrics: Sequence of metric column names to analyze.
    - group_col: Column indicating the group labels. Values are mapped through
      ``group_dict`` before modelling and plotting.
    - path: Optional directory. If given, the figure is saved as
      ``{path}/group_comparisons.png``.
    - scale_data: If True, z-score every metric with ``StandardScaler``
      before modelling and plotting.
    - n_comp: Number of comparisons for the threshold. Tukey HSD runs only
      when the model p-value is below ``0.05 / n_comp``.

    Returns:
    - sig_list: one dict per model term and metric with the coefficient, its
      p-value, the model R², the F statistic, the model p-value, the F-test
      degrees of freedom ("(df_model, df_resid)") and per-group mean/sd.
    - sig_list_pairwise: Tukey HSD results (one dict per group pair) for the
      metrics whose model p-value passed the threshold.
    """
    sig_list = []
    sig_list_pairwise = []

    # Optionally z-score the metrics, on a copy so the caller's frame is untouched.
    df_scaled = df.copy()
    if scale_data:
        scaler = StandardScaler()
        df_scaled[metrics] = scaler.fit_transform(df[metrics])

    # Human-readable group labels.
    df_scaled[group_col] = df_scaled[group_col].replace(group_dict)

    # Long format (one row per observation and metric) for seaborn.
    df_melted = df_scaled.melt(id_vars=[group_col], value_vars=metrics, var_name='Metric', value_name='Value')
    df_melted['Metric'] = df_melted['Metric'].replace(names)

    plt.figure(figsize=set_size(fraction=1), dpi=300)

    # Group means with confidence-interval error bars.
    sns.pointplot(
        data=df_melted,
        x='Metric',
        y='Value',
        hue=group_col,
        palette=colors,
        capsize=0.2,
        markers="s",
        errorbar='ci',
        dodge=0.4 if group_col == "class" else 0.55,
    )

    # Jittered individual data points.
    sns.stripplot(
        data=df_melted,
        x='Metric',
        y='Value',
        hue=group_col,
        dodge=True,
        legend=None,
        palette=colors,
        marker='o',
        edgecolor='gray',
        alpha=0.5
    )

    for x_pos, metric in enumerate(metrics):
        model = smf.ols(f"{metric} ~ C({group_col})", data=df_scaled).fit()
        p_value = model.f_pvalue  # overall F-test p-value for the group effect
        r_squared = model.rsquared
        p_values = model.pvalues
        coefs = model.params

        # F-test degrees of freedom: (number of groups - 1, residual df).
        # The numerator used to be hard-coded to 1, which is only right for
        # two groups.
        f_test_df = f"({int(model.df_model)}, {int(model.df_resid)})"

        # Per-group descriptive statistics.
        grouped_data_stats = {}
        for g in df_scaled[group_col].unique():
            group_values = df_scaled.loc[df_scaled[group_col] == g, metric]
            grouped_data_stats[f'{g} mean'] = group_values.mean()
            grouped_data_stats[f'{g} sd'] = group_values.std()

        for group in p_values.index:
            sig_list.append(
                {
                    "metric": names[metric],
                    "group": group,
                    "coef": coefs[group],
                    "p": p_values[group],
                    "r_squared": r_squared,
                    "f_stat": model.fvalue,
                    "p_model": p_value,
                    "df": f_test_df,
                } | grouped_data_stats
            )

        if p_value < 0.05 / n_comp:
            # The formula-based OLS fit drops incomplete rows by default. Do
            # the same here so NaNs do not propagate into the Tukey statistics.
            tukey_data = df_scaled[[metric, group_col]].dropna()
            pairwise_results = pairwise_tukeyhsd(tukey_data[metric], tukey_data[group_col], alpha=0.05)
            for comparison in pairwise_results.summary().data[1:]:
                group1, group2, mean_diff, p_adj, _, _, _ = comparison
                sig_list_pairwise.append(
                    {
                        "metric": names[metric],
                        "group": f"{group1} vs {group2}",
                        "mean_diff": mean_diff,
                        "p": format_p_value(p_adj),
                    }
                )

        # Annotate the model p-value and R² above this metric's x-position.
        plt.annotate(f"{format_p_value(p_value)}", xy=(x_pos, 0.82), xycoords=('data', 'axes fraction'),
                     ha='center', va='bottom', color='black')
        plt.annotate(f"R²={r_squared:.2f}", xy=(x_pos, 0.88), xycoords=('data', 'axes fraction'),
                     ha='center', va='bottom', color='black')

    plt.ylabel('Values' if not scale_data else 'z-scores')
    plt.xlabel("")
    # Raw string: "\%" is meant literally (escaped percent for LaTeX text).
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=3, frameon=True,
               title=r"Group means with 95\% CI")
    plt.tight_layout()

    if path:
        plt.savefig(f"{path}/group_comparisons.png", dpi=300)

    plt.show()

    return sig_list, sig_list_pairwise
246
+
247
+
248
def plot_regression_coefficients(coefficients, ci_lower, ci_upper, r2_values_df, title, path):
    """
    Plot regression coefficients with asymmetric confidence-interval error bars.

    Each row of ``coefficients`` is a dependent variable ("metric") forming
    one group on the x-axis. Each column is an independent variable
    ("outcome"), drawn as a separately colored series within each group and
    ordered by ``sorted_outcomes``.

    Parameters:
    - coefficients: DataFrame of coefficients (rows: metrics, columns: outcomes).
    - ci_lower, ci_upper: DataFrames with the same rows and column labels,
      holding the lower and upper confidence bounds.
    - r2_values_df: Indexable by metric, giving the model R² shown in each tick
      label. NOTE(review): ``.loc[metric]`` must yield a scalar, so presumably
      a Series.
    - title: Used in the axis and legend titles and in the output file name.
    - path: Directory where ``multivar_reg_{title}.png`` is saved.
    """
    metrics = coefficients.index.to_list()
    outcomes = sorted(coefficients.columns.to_list(), key=sorted_outcomes.index)

    # Reorder the columns of all three tables to the sorted outcome order
    # BEFORE flattening. Previously only the labels were sorted while the
    # values kept the original column order, so coefficients could be drawn
    # under the wrong outcome.
    coefficients = coefficients[outcomes]
    ci_lower = ci_lower[outcomes]
    ci_upper = ci_upper[outcomes]

    # Long format: metric-major, one row per (metric, outcome).
    df_results = pd.DataFrame({
        'Metric': np.repeat(metrics, len(outcomes)),
        'Outcome': np.tile(outcomes, len(metrics)),
        'Coefficient': coefficients.to_numpy().flatten(),
        'CI Lower': ci_lower.to_numpy().flatten(),
        'CI Upper': ci_upper.to_numpy().flatten(),
    })

    # Error-bar lengths below and above the coefficient (not necessarily symmetric).
    df_results['Error Lower'] = df_results['Coefficient'] - df_results['CI Lower']
    df_results['Error Upper'] = df_results['CI Upper'] - df_results['Coefficient']

    plt.figure(figsize=set_size())

    n_metrics = len(metrics)
    n_outcomes = len(outcomes)
    group_spacing = n_outcomes * 0.28  # distance between consecutive metric groups
    bar_width = 0.18  # offset between outcomes within one metric group

    for i, outcome in enumerate(outcomes):
        subset = df_results[df_results['Outcome'] == outcome]
        # Row m of `subset` belongs to metric m; outcome i is shifted by i * bar_width.
        x_pos_subset = [m * group_spacing + i * bar_width for m in range(n_metrics)]

        plt.errorbar(
            x_pos_subset,
            subset['Coefficient'],
            yerr=[subset['Error Lower'], subset['Error Upper']],
            fmt='s',
            label=outcome,
            color=outcome_colors[outcome],
            capsize=6,
            elinewidth=4,
            markeredgewidth=4,
            markersize=5
        )

    # Reference line at zero.
    plt.axhline(y=0, color='black', linestyle='--', linewidth=1)
    plt.xlabel(f"Dependent Variable {title}", fontweight='bold')
    plt.ylabel("Standardized Coefficient", fontweight='bold')

    # One tick per metric, centered under its group, labelled with the model R².
    metric_ticks = [i * group_spacing + (n_outcomes - 1) * bar_width / 2 for i in range(n_metrics)]
    metric_labels = [
        f"{metric}\n" + "(${R}^2$" + f' = {r2_values_df.loc[metric]:.3f})'
        for metric in metrics
    ]
    plt.xticks(ticks=metric_ticks, labels=metric_labels)

    plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.30),
               ncol=3, frameon=True, title=f"Independent Variables {title}")

    plt.tight_layout(pad=2.0)
    plt.savefig(f"{path}/multivar_reg_{title}.png", dpi=300)
    plt.show()
329
+
330
def scatterplot_matrix_with_corr(df, columns, hspace=0.1, wspace=0.1):
    """
    Create a 2D matrix of scatterplots for each pair of columns in the DataFrame.

    Off-diagonal cells show a scatterplot annotated with the Pearson
    correlation coefficient. Diagonal cells show a 20-bin histogram of the
    column.

    Parameters:
        df (pd.DataFrame): The input DataFrame.
        columns (list): Column names to plot; each must be a key of ``names``
            (used for the axis labels).
        hspace (float): Vertical space between subplots, as a fraction of the
            average axis height.
        wspace (float): Horizontal space between subplots, as a fraction of the
            average axis width.
    """
    num_cols = len(columns)

    # squeeze=False keeps `axes` 2-D even for a single column, so axes[i, j]
    # below also works for a 1x1 grid.
    fig, axes = plt.subplots(nrows=num_cols, ncols=num_cols, figsize=(15, 15), squeeze=False)
    fig.subplots_adjust(hspace=hspace, wspace=wspace)

    for i, col1 in enumerate(columns):
        for j, col2 in enumerate(columns):
            ax = axes[i, j]

            if i == j:
                # Diagonal: distribution of the single column.
                ax.hist(df[col1], bins=20, color='lightblue')
                ax.set_title("")
                ax.grid(False)
            else:
                # Off-diagonal: col2 on x, col1 on y, with Pearson's r.
                ax.scatter(df[col2], df[col1], alpha=0.6)
                ax.grid(False)
                corr_coeff = pearsonr(df[col2], df[col1])[0]
                ax.annotate(f'r = {corr_coeff:.2f}', xy=(0.5, 0.85), xycoords='axes fraction',
                            ha='center', color='red', fontsize=17)

            # Only the bottom row gets x labels and only the left column gets y labels.
            if i < num_cols - 1:
                ax.set_xticklabels([])
            else:
                ax.set_xlabel(names[col2])

            if j > 0:
                ax.set_yticklabels([])
            else:
                ax.set_ylabel(names[col1])

    plt.show()
372
+
373
+
374
+
375
def calculate_residuals(df, col, control_vars):
    """
    Regress a single column on control variables and return the residuals.

    Numeric controls are standardized. Non-numeric controls are one-hot
    encoded with the first level dropped. The model is an OLS fit with an
    intercept.

    Parameters:
        df (pd.DataFrame): The input DataFrame. It is not modified.
        col (str): The column for which to calculate residuals. Values that
            cannot be parsed as numbers become NaN.
        control_vars (list): The control variables (numeric and/or categorical).

    Returns:
        pd.Series: The residuals for the column after regressing out the control variables.
    """
    controls = df[control_vars]
    numeric_features = controls.select_dtypes(include=[np.number]).columns.tolist()
    categorical_features = controls.select_dtypes(exclude=[np.number]).columns.tolist()

    # Scale numeric controls and one-hot encode categorical ones.
    transformers = [('num', StandardScaler(), numeric_features)]
    if categorical_features:
        transformers.append(('cat', _dense_one_hot_encoder(), categorical_features))
    preprocessor = ColumnTransformer(transformers, remainder='passthrough')

    # Coerce into a local Series instead of writing back into `df`. The old
    # in-place assignment silently changed the caller's DataFrame, and pandas
    # warns when `df` is a slice of another frame.
    y = pd.to_numeric(df[col], errors='coerce')

    X = preprocessor.fit_transform(controls)
    X = sm.add_constant(X)  # intercept term

    model = sm.OLS(y, X).fit()
    return model.resid


def _dense_one_hot_encoder():
    """Return a drop-first OneHotEncoder with dense output, for old and new scikit-learn."""
    try:
        # scikit-learn >= 1.2 calls this option `sparse_output`; the old
        # `sparse` keyword was removed in 1.4.
        return OneHotEncoder(drop='first', sparse_output=False)
    except TypeError:
        # scikit-learn < 1.2 only knows the `sparse` keyword.
        return OneHotEncoder(drop='first', sparse=False)
409
+
410
def scatterplot_matrix_with_partial_corr(df, columns1, columns2, control_vars, hspace=0.1, wspace=0.1):
    """
    Create a matrix of scatterplots of residuals, controlling for background variables.

    Every column in ``columns1`` and ``columns2`` is first residualized on
    ``control_vars`` via ``calculate_residuals``. Each cell then plots the
    residuals of a ``columns2`` entry (x) against those of a ``columns1``
    entry (y) and is annotated with their Pearson correlation, i.e. the
    partial correlation.

    Parameters:
        df (pd.DataFrame): The input DataFrame.
        columns1 (list): Columns on the y-axis, one subplot row each.
        columns2 (list): Columns on the x-axis, one subplot column each.
        control_vars (list): Control variables to regress out (may include
            categorical variables).
        hspace (float): Vertical space between subplots, as a fraction of the
            average axis height.
        wspace (float): Horizontal space between subplots, as a fraction of the
            average axis width.
    """
    num_rows = len(columns1)
    num_cols = len(columns2)

    # squeeze=False keeps `axes` 2-D, so axes[i, j] also works when either
    # list has a single entry (which previously raised IndexError).
    fig, axes = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=(15, 15), squeeze=False)
    fig.subplots_adjust(hspace=hspace, wspace=wspace)

    # Residualize every distinct column once, in a deterministic order.
    residuals = {
        col: calculate_residuals(df, col, control_vars)
        for col in dict.fromkeys([*columns1, *columns2])
    }

    for i, col1 in enumerate(columns1):
        for j, col2 in enumerate(columns2):
            ax = axes[i, j]

            ax.scatter(residuals[col2], residuals[col1], alpha=0.8, s=10)

            # Best effort: a failing pair is reported and shown as nan.
            try:
                partial_corr, _ = pearsonr(residuals[col2], residuals[col1])
            except Exception as e:
                print(f"Error calculating partial correlation for {col1} and {col2}: {e}")
                partial_corr = np.nan

            ax.annotate(f'pcorr = {partial_corr:.2f}', xy=(0.5, 0.85), xycoords='axes fraction',
                        ha='center', fontsize=17, color='red')
            ax.grid(False)

            # x labels only on the bottom row, y labels only on the leftmost column.
            if i < num_rows - 1:
                ax.set_xticklabels([])
            else:
                ax.set_xlabel(f'{names[col2]}')

            if j > 0:
                ax.set_yticklabels([])
            else:
                ax.set_ylabel(f'{names[col1]}')

    plt.show()
467
+
468
+
469
def plot_stepwise_regression_results(df, group, path, title):
    """
    Compare the adjusted R² of a control-only model with control + one NLP metric.

    One bar position is drawn per row of ``df``. A blue bar shows the full
    model's adjusted R², overlaid by a gray bar for the control model. Each
    full-model bar is annotated with its p-value.

    Parameters:
    - df: One row per added metric, with columns 'metric', 'r2_adj_full',
      'p_value', 'r2_adj_control', 'control' (an iterable of control-variable
      names) and 'psychiatric_score'. The control-model R², the controls and
      the score are read from the first row only.
    - group: Label appended to the y-axis title.
    - path: Directory where ``stepwise_reg_{title}.png`` is saved.
    - title: Used in the output file name.
    """
    # The control-model R² is assumed to be the same across all rows.
    base_r2 = df['r2_adj_control'].iloc[0]
    metrics = df['metric'].to_list()
    r2_full = df['r2_adj_full'].to_list()
    p_values = df['p_value'].to_list()
    controls = df["control"].iloc[0]
    score = names[df["psychiatric_score"].iloc[0]]
    control_string = " + ".join(names[control] for control in controls)

    # Numeric bar positions. The p-value annotations are placed by position,
    # and the tick labels are set explicitly instead of being read back from
    # an undrawn categorical axis. The previous `r2_full[i]` was a label
    # lookup that broke whenever df's index was not 0..n-1.
    x = np.arange(len(metrics))

    plt.figure(figsize=set_size(), dpi=300)

    # Raw f-strings: "$\sim$" is meant literally for math text.
    plt.bar(x, r2_full, color='deepskyblue', width=0.4,
            label=rf'Control Model + NLP Metric $\sim$ {score}')
    plt.bar(x, [base_r2] * len(metrics), color='gray', width=0.4,
            label=rf'{control_string} $\sim$ {score}')

    for xi, r2, p_val in zip(x, r2_full, p_values):
        plt.text(xi, r2 + 0.02, f'p={p_val:.3f}', ha='center', color='black')

    plt.ylabel('${R}^2$ adj.' + f' {score} {group}')

    # Leave headroom above the tallest bar, including the control bar (which
    # can exceed the full model's adjusted R²).
    plt.ylim(0, max(max(r2_full), base_r2) + 0.05)

    # Translate metric identifiers into display names where available.
    plt.xticks(ticks=x, labels=[names.get(str(m), str(m)) for m in metrics])

    plt.legend(loc='upper right', frameon=True)
    plt.tight_layout()
    plt.savefig(f"{path}/stepwise_reg_{title}.png", dpi=300)
    plt.show()
518
+
519
def plot_regression_with_levels(df, x_col, y_col_prefix, feature_suffix, title, path):
    """
    Scatter ``x_col`` against each matching y column and overlay a linear fit.

    A y column matches when its name starts with ``y_col_prefix`` and ends
    with ``feature_suffix``. Each match gets its own color, its points, and an
    ordinary least-squares line labelled with the fit's R². The column's
    "level" is the third underscore-separated token of its name
    (``col.split('_')[2]``). The legend label is looked up in ``names`` under
    ``y_col_prefix + level + "_" + feature_suffix``. Only a single y-axis is
    drawn.

    Parameters:
    - df: DataFrame holding ``x_col`` and the matching y columns.
    - x_col: Predictor column (x-axis); also a key of ``names``.
    - y_col_prefix: Prefix shared by the y columns; also a key of ``names``,
      used for the y-axis label.
    - feature_suffix: Suffix shared by the y columns; also used in the file name.
    - title: Currently unused; kept for interface compatibility.
    - path: Directory where ``synthetic_p_{feature_suffix}_with_length.png``
      is saved.
    """
    fig, ax1 = plt.subplots(figsize=set_size(), dpi=300)

    matching_columns = [
        col for col in df.columns
        if col.startswith(y_col_prefix) and col.endswith(feature_suffix)
    ]

    # One color per matching column.
    palette = sns.color_palette("colorblind", len(matching_columns))

    for color, col in zip(palette, matching_columns):
        level = col.split('_')[2]

        # Drop rows with missing values. Sorting by x draws the fitted line
        # left to right; it changes neither the fit nor the scattered points.
        valid_data = df[[x_col, col]].dropna().sort_values(x_col)
        if valid_data.empty:
            continue

        X = valid_data[[x_col]]
        y = valid_data[col]

        model = LinearRegression()
        model.fit(X, y)
        y_pred = model.predict(X)
        r2 = r2_score(y, y_pred)

        ax1.scatter(X, y, color=color, label=None, alpha=0.6, s=10)
        ax1.plot(X, y_pred, color=color,
                 label=f'{names[y_col_prefix+level+"_"+feature_suffix]} (R² = {r2:.2f})', linewidth=3)

    ax1.set_xlabel(names[x_col])
    ax1.set_ylabel(f'{names[y_col_prefix]}')

    ax1.legend(loc="upper right")
    plt.savefig(f"{path}/synthetic_p_{feature_suffix}_with_length.png", dpi=300)
    plt.show()