gsMap3D 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1409 @@
1
+ import gc
2
+ import logging
3
+ import warnings
4
+ from concurrent.futures import ThreadPoolExecutor
5
+ from pathlib import Path
6
+ from typing import Literal
7
+
8
+ import distinctipy
9
+ import matplotlib
10
+ import matplotlib.axes
11
+ import matplotlib.colors as mcolors
12
+ import matplotlib.pyplot as plt
13
+ import numpy as np
14
+ import pandas as pd
15
+ import plotly.express as px
16
+ import plotly.graph_objects as go
17
+ import scipy.stats as stats
18
+ from rich import print
19
+ from scipy.cluster.hierarchy import leaves_list, linkage
20
+ from scipy.spatial import KDTree
21
+ from tqdm import tqdm
22
+
23
+ warnings.filterwarnings("ignore")
24
+
25
+
26
+ def remove_outliers_MAD(data, threshold=3.5):
27
+ """
28
+ Remove outliers based on Median Absolute Deviation (MAD).
29
+ """
30
+ if isinstance(data, pd.Series | pd.DataFrame):
31
+ data_values = data.values.flatten()
32
+ else:
33
+ data_values = np.asarray(data).flatten()
34
+
35
+ if len(data_values) == 0:
36
+ return data, np.ones(len(data), dtype=bool)
37
+
38
+ median = np.nanmedian(data_values)
39
+ mad = np.nanmedian(np.abs(data_values - median))
40
+ if mad == 0:
41
+ return data, np.ones(len(data), dtype=bool)
42
+
43
+ modified_z_score = 0.6745 * (data_values - median) / mad
44
+ mask = np.abs(modified_z_score) <= threshold
45
+
46
+ if isinstance(data, pd.Series):
47
+ return data[mask], mask
48
+ elif isinstance(data, np.ndarray):
49
+ if len(data.shape) == 1:
50
+ return data[mask], mask
51
+ else:
52
+ return data.flatten()[mask], mask
53
+ return data[mask], mask
54
+
55
+
56
def load_ldsc(ldsc_input_file):
    """Load a gzip-compressed spatial LDSC result table.

    Only the ``spot`` and ``p`` columns are read; the frame is indexed by
    spot and a ``logp`` column holding ``-log10(p)`` is appended.
    """
    result = pd.read_csv(
        ldsc_input_file,
        compression="gzip",
        usecols=["spot", "p"],
        dtype={"spot": str, "p": float},
        index_col="spot",
    )
    result["logp"] = -np.log10(result.p)
    return result
66
+
67
+
68
+ # %%
69
# %%
def load_st_coord(adata, feature_series: pd.Series, annotation):
    """Join spatial coordinates with a per-spot feature (and optional annotation).

    Args:
        adata: AnnData-like object; must carry coordinates in ``obsm["spatial"]``.
        feature_series: Per-spot values (index = spot names) to attach; the
            output is aligned to this index.
        annotation: Name of an ``adata.obs`` column to attach as an
            ``annotation`` column, or ``None`` to skip.

    Returns:
        DataFrame with columns ``sx``, ``sy``, the feature column and,
        when requested, ``annotation``.
    """
    spot_name = adata.obs_names.to_list()
    assert "spatial" in adata.obsm.keys(), "spatial coordinates are not found in adata.obsm"

    # Normalize the spatial matrix to a DataFrame with named columns.
    space_coord = adata.obsm["spatial"]
    if isinstance(space_coord, np.ndarray):
        space_coord = pd.DataFrame(space_coord, columns=["sx", "sy"], index=spot_name)
    else:
        space_coord = pd.DataFrame(space_coord.values, columns=["sx", "sy"], index=spot_name)

    # Keep only spots present in the feature series, in the series' order.
    space_coord = space_coord[space_coord.index.isin(feature_series.index)]
    space_coord_concat = pd.concat([space_coord.loc[feature_series.index], feature_series], axis=1)
    # (removed a dead no-op `space_coord_concat.head()` call here)
    if annotation is not None:
        annotation = pd.Series(
            adata.obs[annotation].values, index=adata.obs_names, name="annotation"
        )
        space_coord_concat = pd.concat([space_coord_concat, annotation], axis=1)
    return space_coord_concat
89
+
90
+
91
def estimate_plotly_point_size(coordinates, DEFAULT_PIXEL_WIDTH=1000):
    """Estimate a plotly marker size matched to the data's point spacing.

    Returns ``((pixel_width, pixel_height), point_size)``: the pixel
    dimensions scale the data extent so the longer side equals
    ``DEFAULT_PIXEL_WIDTH``, and ``point_size`` is the median
    nearest-neighbour distance expressed in those pixels.
    """
    # Accept DataFrames and other array-likes transparently.
    if hasattr(coordinates, 'values'):
        coordinates = coordinates.values
    coordinates = np.asarray(coordinates)

    # k=2 returns each point itself plus its nearest neighbour.
    nn_distances, _ = KDTree(coordinates).query(coordinates, k=2)
    typical_spacing = np.median(nn_distances[:, 1])

    # Data extent along each axis.
    span_x = np.max(coordinates[:, 0]) - np.min(coordinates[:, 0])
    span_y = np.max(coordinates[:, 1]) - np.min(coordinates[:, 1])

    # Map the larger span onto the target pixel width.
    scale = DEFAULT_PIXEL_WIDTH / max(span_x, span_y)
    return (span_x * scale, span_y * scale), typical_spacing * scale
111
+
112
+
113
def estimate_matplotlib_scatter_marker_size(ax: matplotlib.axes.Axes, coordinates: np.ndarray,
                                            x_limits: tuple | None = None,
                                            y_limits: tuple | None = None) -> float:
    """
    Estimates the appropriate marker size to make adjacent markers touch.

    Calculates the size 's' for a square marker (in points^2) such that its
    diameter corresponds to the mean nearest-neighbour distance of the
    points, accounting for the plot's aspect ratio and rendered dimensions.

    Args:
        ax (matplotlib.axes.Axes): The subplot object. The function will
            temporarily set its limits and aspect ratio to ensure the
            transformation from data units to display units is accurate.
        coordinates (np.ndarray): Array of shape (n, 2) with (x, y) points.
        x_limits (Optional[tuple]): Optional (min, max) override for the x-axis.
        y_limits (Optional[tuple]): Optional (min, max) override for the y-axis.

    Returns:
        float: The estimated marker size 's' (in points^2) for ax.scatter().
    """
    # 1. Set up the axes so data->display transforms are accurate.
    # Equal aspect makes the x-width scaling valid for both axes.
    ax.set_aspect('equal')

    # Use provided limits if available, otherwise calculate from data.
    if x_limits is not None:
        x_data_min, x_data_max = x_limits
    else:
        x_data_min, x_data_max = np.min(coordinates[:, 0]), np.max(coordinates[:, 0])

    if y_limits is not None:
        # BUGFIX: this branch previously ignored y_limits and recomputed
        # the bounds from the data, so shared y-limits never took effect.
        y_data_min, y_data_max = y_limits
    else:
        y_data_min, y_data_max = np.min(coordinates[:, 1]), np.max(coordinates[:, 1])

    ax.set_xlim(x_data_min, x_data_max)
    ax.set_ylim(y_data_min, y_data_max)

    # Force a draw of the canvas to finalize the transformations.
    ax.figure.canvas.draw()

    # 2. Required marker radius in data units: half the mean distance to the
    # nearest neighbour (k=2 query returns self + nearest neighbour).
    tree = KDTree(coordinates)
    distances, _ = tree.query(coordinates, k=2)
    radius_data = np.mean(distances[:, 1]) / 2

    # 3. Convert the data radius to physical units (points) by transforming
    # the axes' x-extent from data coords -> pixels -> inches.
    x_display_min, _ = ax.transData.transform((x_data_min, y_data_min))
    x_display_max, _ = ax.transData.transform((x_data_max, y_data_max))

    x_inch_min, _ = ax.figure.dpi_scale_trans.inverted().transform((x_display_min, 0))
    x_inch_max, _ = ax.figure.dpi_scale_trans.inverted().transform((x_display_max, 0))

    width_inch = x_inch_max - x_inch_min
    width_data = x_data_max - x_data_min

    # Scale the data radius by physical-width / data-width. Valid for both
    # axes because the aspect ratio is 'equal'.
    radius_inch = (radius_data / width_data) * width_inch

    # Convert inches to points (1 inch = 72 points).
    radius_points = radius_inch * 72

    # 4. scatter's 's' is the marker AREA in points^2; for a square marker,
    # area = (2 * radius)^2. The 1.2 factor adds slight visual overlap.
    square_marker_size = (2 * radius_points) ** 2

    return square_marker_size * 1.2
195
+
196
+
197
def draw_scatter(
    space_coord_concat,
    title=None,
    fig_style: Literal["dark", "light"] = "light",
    point_size: int = None,
    width=800,
    height=600,
    annotation=None,
    color_by="logp",
    color_continuous_scale=None,
    plot_origin="upper",
):
    """Render a spatial scatter plot of per-spot values with plotly express.

    Spots are placed at (sx, sy), coloured by ``color_by`` on a continuous
    scale from 0 to the column maximum, with axes hidden and an equal
    aspect ratio. ``plot_origin='upper'`` flips the y-axis so the origin
    sits at the top-left (image convention).
    """
    # Theme follows the requested figure style.
    px.defaults.template = "plotly_dark" if fig_style == "dark" else "plotly_white"

    if color_continuous_scale is None:
        # Default diverging blue->red scale, listed from low (0) to high (1).
        color_continuous_scale = [
            (0, "#313695"),  # Deep Blue
            (1 / 8, "#4575b4"),  # Dark Blue
            (2 / 8, "#74add1"),  # Medium Blue
            (3 / 8, "#abd9e9"),  # Sky Blue
            (4 / 8, "#e0f3f8"),  # Light Blue
            (5 / 8, "#fee090"),  # Light Orange
            (6 / 8, "#fdae61"),  # Orange
            (7 / 8, "#f46d43"),  # Red-Orange
            (1, "#d73027"),  # Red
        ]

    fig = px.scatter(
        space_coord_concat,
        x="sx",
        y="sy",
        color=color_by,
        symbol="annotation" if annotation is not None else None,
        title=title,
        color_continuous_scale=color_continuous_scale,
        range_color=[0, max(space_coord_concat[color_by])],
    )

    # Apply an explicit marker size when the caller provides one.
    if point_size is not None:
        fig.update_traces(marker=dict(size=point_size, symbol="circle"))

    # Fixed canvas size, tight margins, legend just outside the top-right
    # corner, a horizontal colorbar below the plot, and a centered title.
    fig.update_layout(
        autosize=False,
        width=width,
        height=height,
        margin=dict(l=0, r=0, t=20, b=10),
        legend=dict(
            yanchor="top",
            y=0.95,
            xanchor="left",
            x=1.0,
            font=dict(size=10),
        ),
        coloraxis_colorbar=dict(
            orientation="h",  # horizontal bar below the plot
            x=0.5,
            y=-0.0,
            xanchor="center",
            yanchor="top",
            len=0.75,  # relative to plot width
            title=dict(
                text="-log10(p)" if color_by == "logp" else color_by,
                side="top",
            ),
        ),
        title=dict(
            y=0.98,
            x=0.5,
            xanchor="center",
            yanchor="top",
            font=dict(size=20),
        ),
    )

    # Hide gridlines, ticks and axis titles; lock x scale to y for an equal
    # aspect ratio, and flip y when the origin should be at the top.
    fig.update_xaxes(
        showgrid=False,
        zeroline=False,
        showticklabels=False,
        title=None,
        scaleanchor="y",
    )
    fig.update_yaxes(
        showgrid=False,
        zeroline=False,
        showticklabels=False,
        title=None,
        autorange="reversed" if plot_origin == "upper" else True,
    )

    return fig
317
+
318
+
319
def _create_color_map(category_list: list, hex=False, rng=42) -> dict[str, tuple]:
    """Map each category to a visually distinct colour.

    NaN-like labels ('nan'/'none'/'null', case-insensitive) get a fixed
    light-grey colour. With ``hex=True`` the colours are returned as hex
    strings instead of RGB tuples.
    """
    unique_categories = sorted(set(category_list), key=str)

    # Separate the NaN-like labels so they receive grey, not a palette slot.
    nan_values = [v for v in unique_categories if str(v).lower() in ['nan', 'none', 'null']]
    other_categories = [v for v in unique_categories if v not in nan_values]

    # Generate distinct colours only for the real categories.
    if other_categories:
        palette = distinctipy.get_colors(len(other_categories), rng=rng)
        color_map = dict(zip(other_categories, palette, strict=False))
    else:
        color_map = {}

    lightgrey_rgb = (0.827, 0.827, 0.827)
    color_map.update((v, lightgrey_rgb) for v in nan_values)

    if hex:
        # Convert RGB tuples to hex format
        color_map = {category: distinctipy.get_hex(color_map[category]) for category in color_map}
        print("Generated color map in hex format")
    return color_map
345
+
346
+
347
class VisualizeRunner:
    """Generates spatial visualisation outputs (per-sample and per-trait plots)."""

    # Diverging red -> blue palette shared by the matplotlib colormaps below.
    custom_colors_list = [
        '#d73027', '#f46d43', '#fdae61', '#fee090', '#e0f3f8',
        '#abd9e9', '#74add1', '#4575b4', '#313695'
    ]

    def __init__(self, config):
        # Configuration object carrying output paths, trait names and plot options.
        self.config = config
355
+
356
    def _generate_visualizations(self, obs_ldsc_merged: pd.DataFrame):
        """Generate all visualizations.

        For every sample: a multi-trait matplotlib grid plus one interactive
        plotly annotation scatter per configured annotation. Afterwards, one
        multi-sample grid plot per annotation. All outputs are written under
        ``config.visualization_result_dir``; nothing is returned.
        """

        # Create visualization directories
        single_sample_folder = self.config.visualization_result_dir / 'single_sample_multi_trait_plot'
        annotation_folder = self.config.visualization_result_dir / 'annotation_distribution'
        annotation_folder.mkdir(exist_ok=True, parents=True)

        sample_names_list = sorted(obs_ldsc_merged['sample_name'].unique())

        for sample_name in tqdm(sample_names_list, desc='Generating visualizations'):
            # Multi-trait plot
            traits_png = single_sample_folder / 'static_png' / f'{sample_name}_gwas_traits_pvalues.jpg'
            traits_pdf = single_sample_folder / 'static_pdf' / f'{sample_name}_gwas_traits_pvalues.pdf'  # Added PDF output path

            # Create parent directories for the output files
            traits_png.parent.mkdir(exist_ok=True, parents=True)
            traits_pdf.parent.mkdir(exist_ok=True, parents=True)

            # Call the modified matplotlib-based plotting function.
            # This function saves files directly and does not return a figure object.
            self._create_single_sample_multi_trait_plots(
                obs_ldsc_merged=obs_ldsc_merged,
                trait_names=self.config.trait_name_list,
                sample_name=sample_name,
                output_png_path=traits_png,
                output_pdf_path=traits_pdf,
                max_cols=self.config.single_sample_multi_trait_max_cols,
                subsample_n_points=self.config.subsample_n_points,
                # Use new parameters from the updated VisualizationConfig
                subplot_width_inches=self.config.single_sample_multi_trait_subplot_width_inches,
                dpi=self.config.single_sample_multi_trait_dpi,
                enable_pdf_output=self.config.enable_pdf_output
            )

            # Annotation distribution plots: marker size is derived from this
            # sample's point spacing so dots touch without overlapping.
            sample_data = obs_ldsc_merged.query(f'sample_name == "{sample_name}"')
            (pixel_width, pixel_height), point_size = estimate_plotly_point_size(sample_data[['sx', 'sy']].values)

            for annotation in self.config.cauchy_annotations:
                annotation_dir = annotation_folder / annotation
                annotation_dir.mkdir(exist_ok=True)

                # NOTE(review): the colour map is built from ALL samples'
                # categories (not just this sample's) so colours stay
                # consistent across samples; it is recomputed per sample.
                annotation_color_map = _create_color_map(obs_ldsc_merged[annotation].unique(), hex=True)
                fig = self._draw_scatter(sample_data, title=f'{annotation}_{sample_name}',
                                         point_size=point_size, width=pixel_width, height=pixel_height,
                                         hover_text_list=self.config.hover_text_list,
                                         color_by=annotation, color_map=annotation_color_map)

                # Save both a static image and an interactive HTML version.
                annotation_png = annotation_dir / f'{sample_name}_{annotation}.png'
                annotation_html = annotation_dir / f'{sample_name}_{annotation}.html'
                fig.write_image(annotation_png)
                fig.write_html(annotation_html)

        # Generate multi-sample annotation plots
        print("Generating multi-sample annotation plots...")
        sample_count = len(sample_names_list)
        n_rows, n_cols = self._calculate_optimal_grid_layout(
            item_count=sample_count,
            max_cols=self.config.single_sample_multi_trait_max_cols
        )

        for annotation in tqdm(self.config.cauchy_annotations,
                               desc='Generating multi-sample annotation plots'):
            annotation_dir = annotation_folder / annotation
            annotation_dir.mkdir(exist_ok=True)

            self._create_multi_sample_annotation_plot(
                obs_ldsc_merged=obs_ldsc_merged,
                annotation=annotation,
                sample_names_list=sample_names_list,
                output_dir=annotation_dir,
                n_rows=n_rows,
                n_cols=n_cols
            )
431
+
432
    def _create_single_trait_multi_sample_plots(self, obs_ldsc_merged: pd.DataFrame):
        """Generate single trait multi-sample visualizations using matplotlib.

        For each trait in ``config.trait_name_list``, writes one grid figure
        (JPG, optionally PDF) showing every sample's spatial distribution of
        that trait under ``single_trait_multi_sample_plot/``. Traits missing
        from the data are skipped with a warning.
        """

        trait_names = self.config.trait_name_list

        # Create output directory
        single_trait_folder = self.config.visualization_result_dir / 'single_trait_multi_sample_plot'
        single_trait_folder.mkdir(exist_ok=True, parents=True)

        # Copy so downstream helpers cannot mutate the caller's frame.
        # (Coordinates are assumed to be in the 'sx'/'sy' columns.)
        obs_ldsc_merged = obs_ldsc_merged.copy()

        # Grid dimensions are driven by how many samples must fit per figure.
        sample_count = obs_ldsc_merged['sample_name'].nunique()
        n_rows, n_cols = self._calculate_optimal_grid_layout(
            item_count=sample_count,
            max_cols=self.config.single_trait_multi_sample_max_cols
        )

        print(f"Generating plots for {len(trait_names)} traits with {sample_count} samples in {n_rows}x{n_cols} grid")

        # Generate visualization for each trait
        for trait in tqdm(trait_names, desc="Generating single trait multi-sample plots"):
            if trait not in obs_ldsc_merged.columns:
                print(f"Warning: Trait {trait} not found in data. Skipping.")
                continue

            self._create_single_trait_multi_sample_matplotlib_plot(
                obs_ldsc_merged=obs_ldsc_merged,
                trait_abbreviation=trait,
                output_png_path=single_trait_folder / f'{trait}_multi_sample_plot.jpg',
                output_pdf_path=single_trait_folder / f'{trait}_multi_sample_plot.pdf',
                n_rows=n_rows,
                n_cols=n_cols,
                subplot_width_inches=self.config.single_trait_multi_sample_subplot_width_inches,
                scaling_factor=self.config.single_trait_multi_sample_scaling_factor,
                dpi=self.config.single_trait_multi_sample_dpi,
                enable_pdf_output=self.config.enable_pdf_output,
                share_coords=self.config.share_coords
            )
472
+
473
+ def _calculate_optimal_grid_layout(self, item_count: int, max_cols: int = 8) -> tuple[int, int]:
474
+ """
475
+ Calculate optimal grid dimensions (rows, cols) for displaying items in a grid.
476
+
477
+ Args:
478
+ item_count: Number of items to display
479
+ max_cols: Maximum number of columns allowed
480
+
481
+ Returns:
482
+ tuple: (n_rows, n_cols) for optimal grid layout
483
+ """
484
+ import math
485
+
486
+ if item_count <= 0:
487
+ return 1, 1
488
+
489
+ # For small counts, use simple layouts favoring horizontal arrangement
490
+ if item_count <= 3:
491
+ return 1, item_count
492
+ elif item_count <= 6:
493
+ return 2, math.ceil(item_count / 2)
494
+ elif item_count <= 12:
495
+ return 3, math.ceil(item_count / 3)
496
+ else:
497
+ # For larger counts, try to create a roughly square grid
498
+ # but respect the max_cols constraint
499
+ optimal_cols = min(math.ceil(math.sqrt(item_count)), max_cols)
500
+ optimal_rows = math.ceil(item_count / optimal_cols)
501
+
502
+ # If we hit the max_cols limit, recalculate rows
503
+ if optimal_cols >= max_cols:
504
+ n_cols = max_cols
505
+ n_rows = math.ceil(item_count / max_cols)
506
+ else:
507
+ n_rows = optimal_rows
508
+ n_cols = optimal_cols
509
+
510
+ print(f"Calculated grid layout: {n_rows} rows × {n_cols} cols for {item_count} items")
511
+ return n_rows, n_cols
512
+
513
+ def _create_single_trait_multi_sample_matplotlib_plot(self, obs_ldsc_merged: pd.DataFrame, trait_abbreviation: str,
514
+ sample_name_list: list[str] | None = None,
515
+ output_png_path: Path | None = None,
516
+ output_pdf_path: Path | None = None,
517
+ n_rows: int = 6, n_cols: int = 8,
518
+ subplot_width_inches: float = 4.0,
519
+ scaling_factor: float = 1.0, dpi: int = 300,
520
+ enable_pdf_output: bool = True,
521
+ show=False,
522
+ share_coords: bool = False
523
+ ):
524
+ """
525
+ Create and save a visualization for a specific trait showing all samples
526
+ """
527
+
528
+ matplotlib.rcParams['figure.dpi'] = dpi
529
+ print(f"Creating visualization for {trait_abbreviation}")
530
+
531
+ # Check if trait exists in the dataframe
532
+ if trait_abbreviation not in obs_ldsc_merged.columns:
533
+ print(f"Warning: Trait {trait_abbreviation} not found in the data. Skipping.")
534
+ return
535
+
536
+ # Set font to Arial with fallbacks to avoid warnings
537
+ plt.rcParams['font.family'] = 'sans-serif'
538
+ plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans', 'Liberation Sans', 'Bitstream Vera Sans', 'sans-serif']
539
+
540
+ custom_cmap = mcolors.LinearSegmentedColormap.from_list('custom_cmap', self.custom_colors_list)
541
+ custom_cmap = custom_cmap.reversed()
542
+
543
+ # Calculate figure size based on subplot dimensions
544
+ fig_width = n_cols * subplot_width_inches
545
+ fig_height = n_rows * subplot_width_inches
546
+
547
+ # Create figure with title
548
+ fig = plt.figure(figsize=(fig_width, fig_height))
549
+
550
+ # Add main title
551
+ fig.suptitle(trait_abbreviation, fontsize=24, fontweight='bold', y=0.98)
552
+
553
+ # Create grid of subplots
554
+ grid_specs = fig.add_gridspec(nrows=n_rows, ncols=n_cols, wspace=0.1, hspace=0.1)
555
+
556
+ _, pass_filter_mask = remove_outliers_MAD(obs_ldsc_merged[trait_abbreviation])
557
+ obs_ldsc_merged_filtered = obs_ldsc_merged[pass_filter_mask]
558
+
559
+ pd_min = 0
560
+ pd_max = obs_ldsc_merged_filtered[trait_abbreviation].quantile(0.999)
561
+
562
+ print(f"Color scale min: {pd_min}, max: {pd_max}")
563
+ # Get list of sample names - use provided list or fallback to sorted unique
564
+ if sample_name_list is None:
565
+ sample_name_list = sorted(obs_ldsc_merged_filtered['sample_name'].unique())
566
+
567
+ # get the x and y limit if share coordinates
568
+ if share_coords:
569
+ x_limits = (obs_ldsc_merged_filtered['sx'].min(), obs_ldsc_merged_filtered['sx'].max())
570
+ y_limits = (obs_ldsc_merged_filtered['sy'].min(), obs_ldsc_merged_filtered['sy'].max())
571
+ else:
572
+ x_limits = None
573
+ y_limits = None
574
+
575
+ # Create a scatter plot for each sample
576
+ for position_num, select_sample_name in enumerate(sample_name_list[:n_rows * n_cols], 1):
577
+ # Calculate row and column in the grid
578
+ row = (position_num - 1) // n_cols
579
+ col = (position_num - 1) % n_cols
580
+
581
+ # Create subplot
582
+ ax = fig.add_subplot(grid_specs[row, col])
583
+
584
+ # Get data for this sample
585
+ sample_data = obs_ldsc_merged_filtered[obs_ldsc_merged_filtered['sample_name'] == select_sample_name]
586
+
587
+ point_size = self.estimate_matplitlib_scatter_marker_size(ax, sample_data[['sx', 'sy']].values,
588
+ x_limits=x_limits, y_limits=y_limits)
589
+ point_size *= scaling_factor # Apply scaling factor
590
+ # Create scatter plot
591
+ scatter = ax.scatter(
592
+ sample_data['sx'],
593
+ sample_data['sy'],
594
+ c=sample_data[trait_abbreviation],
595
+ cmap=custom_cmap,
596
+ s=point_size,
597
+ vmin=pd_min,
598
+ vmax=pd_max,
599
+ marker='o',
600
+ edgecolors='none',
601
+ rasterized=True if output_pdf_path is not None and enable_pdf_output else False
602
+ )
603
+
604
+ if self.config.plot_origin == 'upper':
605
+ ax.invert_yaxis()
606
+
607
+ ax.axis('off')
608
+ # Add sample label as title
609
+ ax.set_title(select_sample_name, fontsize=12, pad=None)
610
+
611
+ # Add colorbar to the right side
612
+ cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7]) # [left, bottom, width, height]
613
+ cbar = fig.colorbar(scatter, cax=cbar_ax)
614
+ cbar.set_label('$-\\log_{10}p$', fontsize=12, fontweight='bold')
615
+
616
+ if output_png_path is not None:
617
+ output_png_path.parent.mkdir(parents=True, exist_ok=True)
618
+ plt.savefig(output_png_path, dpi=dpi, bbox_inches='tight', )
619
+
620
+ if output_pdf_path is not None and enable_pdf_output:
621
+ output_pdf_path.parent.mkdir(parents=True, exist_ok=True)
622
+ plt.savefig(output_pdf_path, bbox_inches='tight', )
623
+
624
+ # Close the figure to free memory only if not returning it
625
+ if show:
626
+ plt.show()
627
+
628
+ gc.collect()
629
+ return fig
630
+
631
    def _create_single_sample_multi_trait_plots(self,
                                                obs_ldsc_merged: pd.DataFrame,
                                                trait_names: list[str],
                                                sample_name: str,
                                                # New arguments for output paths, as matplotlib saves directly
                                                output_png_path: Path | None,
                                                output_pdf_path: Path | None,
                                                # Arguments from original function signature, adapted for matplotlib
                                                max_cols: int = 5,
                                                subsample_n_points: int | None = None,
                                                # subplot_width is now interpreted as inches for figsize
                                                subplot_width_inches: float = 4.0,
                                                dpi: int = 300,
                                                enable_pdf_output: bool = True
                                                ):
        """Create and save a grid of per-trait scatter plots for one sample.

        Each trait gets a subplot coloured by its (MAD-filtered) values with
        its own horizontal colorbar. Files are saved directly (PNG always,
        PDF when enabled); nothing is returned and the figure is closed.
        """

        print(f"Creating Matplotlib-based multi-trait visualization for sample: {sample_name}")

        # 1. Filter data for the specific sample and subsample if requested
        sample_plot_data = obs_ldsc_merged[obs_ldsc_merged['sample_name'] == sample_name].copy()
        if subsample_n_points and len(sample_plot_data) > subsample_n_points:
            print(f"Subsampling to {subsample_n_points} points for plotting.")
            sample_plot_data = sample_plot_data.sample(n=subsample_n_points, random_state=42)

        if sample_plot_data.empty:
            print(f"Warning: No data found for sample '{sample_name}'. Skipping plot generation.")
            return

        # 2. Calculate optimal grid layout for subplots
        n_traits = len(trait_names)
        n_rows, n_cols = self._calculate_optimal_grid_layout(item_count=n_traits, max_cols=max_cols)
        print(f"Plotting {n_traits} traits in a {n_rows}x{n_cols} grid.")

        # 3. Determine figure size and create figure and axes
        # Estimate subplot height based on data's aspect ratio to avoid distortion
        x_range = sample_plot_data['sx'].max() - sample_plot_data['sx'].min()
        y_range = sample_plot_data['sy'].max() - sample_plot_data['sy'].min()
        aspect_ratio = y_range / x_range if x_range > 0 else 1.0
        subplot_height_inches = subplot_width_inches * aspect_ratio

        # Calculate total figure size, adding padding for titles and colorbars
        fig_width = subplot_width_inches * n_cols
        fig_height = (subplot_height_inches * n_rows) * 1.2  # Add 20% vertical space for titles/colorbars

        fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), squeeze=False)
        fig.suptitle(f"Sample: {sample_name}", fontsize=16, fontweight='bold')

        # 4. Define custom colormap and font
        # Set font to Arial with fallbacks to avoid warnings
        plt.rcParams['font.family'] = 'sans-serif'
        plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans', 'Liberation Sans', 'Bitstream Vera Sans', 'sans-serif']

        # Reversed palette: low values blue, high values red.
        custom_cmap = matplotlib.colors.LinearSegmentedColormap.from_list('custom_cmap',
                                                                          self.custom_colors_list).reversed()

        # 5. Iterate through traits and create each subplot
        axes_flat = axes.flatten()
        for i, trait in enumerate(trait_names):
            if i >= len(axes_flat):
                break  # Should not happen with correct grid calculation, but is a safe guard

            ax = axes_flat[i]

            # Estimate marker size to fill space without overlap
            point_size = estimate_matplotlib_scatter_marker_size(ax, sample_plot_data[['sx', 'sy']].values)

            # Determine color scale per trait: drop MAD outliers, then cap at
            # the 99.9th percentile to handle remaining extremes.
            sample_trait_data = sample_plot_data[['sx', 'sy', trait]].dropna()
            trait_values, mask = remove_outliers_MAD(sample_trait_data[trait])
            sample_trait_data = sample_trait_data[mask]  # filter out outliers

            vmin = 0
            vmax = trait_values.quantile(0.999)
            # Fallbacks for degenerate scales (all-NaN or all-zero traits).
            if pd.isna(vmax) or vmax == 0:
                vmax = trait_values.max() if trait_values.max() > 0 else 1.0

            # Create the scatter plot (rasterized for compact PDF output)
            scatter = ax.scatter(
                sample_trait_data['sx'],
                sample_trait_data['sy'],
                c=trait_values,
                cmap=custom_cmap,
                s=point_size,
                vmin=vmin,
                vmax=vmax,
                marker='o',
                edgecolors='none',
                rasterized=True if output_pdf_path is not None and enable_pdf_output else False
            )

            ax.set_title(trait, fontsize=16, pad=10, fontweight='bold')
            ax.set_aspect('equal', adjustable='box')

            # Flip y so the origin is top-left (image convention) if configured.
            if self.config.plot_origin == 'upper':
                ax.invert_yaxis()

            ax.axis('off')

            # Add a colorbar to each subplot
            cbar = fig.colorbar(scatter, ax=ax, orientation='horizontal', pad=0.1, fraction=0.05)
            cbar.set_label('$-\\log_{10}p$', fontsize=8)
            cbar.ax.tick_params(labelsize=7)

        # Hide any unused axes in the grid
        for j in range(len(trait_names), len(axes_flat)):
            axes_flat[j].axis('off')
        #
        # # 6. Adjust layout and save the figure
        # fig.tight_layout(rect=[0, 0.03, 1, 0.95])  # Adjust for suptitle and bottom elements

        # Only proceed with saving if paths are provided
        if output_png_path is not None:
            output_png_path.parent.mkdir(parents=True, exist_ok=True)
            plt.savefig(output_png_path, dpi=dpi, bbox_inches='tight', facecolor='white')
            print(f"Saved multi-trait plot for '{sample_name}' to:\n - {output_png_path}")

        if output_pdf_path is not None and enable_pdf_output:
            output_pdf_path.parent.mkdir(parents=True, exist_ok=True)
            plt.savefig(output_pdf_path, bbox_inches='tight', facecolor='white')
            print(f"Saved multi-trait plot for '{sample_name}' to:\n - {output_pdf_path}")

        # Clean up to free memory
        plt.close(fig)
754
+
755
    def _draw_scatter(self, space_coord_concat: pd.DataFrame, title: str | None = None,
                      fig_style: str = 'light', point_size: int | None = None,
                      hover_text_list: list[str] | None = None,
                      width: int = 800, height: int = 600, annotation: str | None = None,
                      color_by: str = 'logp', color_map: dict | None = None):
        """Create scatter plot (adapted from original draw_scatter function).

        Builds a plotly scatter of spatial coordinates ('sx'/'sy' columns of
        ``space_coord_concat``) colored by ``color_by``. Categorical columns get
        a discrete color map and legend; numeric columns get a continuous
        custom color scale with a horizontal colorbar.

        Parameters
        ----------
        space_coord_concat : pd.DataFrame
            Must contain 'sx', 'sy' and the ``color_by`` column.
        title : str | None
            Figure title.
        fig_style : str
            'dark' selects the plotly_dark template, anything else plotly_white.
        point_size : int | None
            Fixed marker size; if None, plotly's default is kept.
        hover_text_list : list[str] | None
            Extra columns to show on hover.
        width, height : int
            Requested figure size in pixels. NOTE(review): a later
            ``update_layout(height=width)`` call overrides ``height`` —
            presumably intentional to force a square canvas, but confirm.
        annotation : str | None
            Column used for marker symbols (numeric branch only).
        color_by : str
            Column used for coloring; 'logp' gets a '-log10(p)' colorbar label.
        color_map : dict | None
            Discrete color mapping for the categorical branch.

        Returns
        -------
        plotly.graph_objects.Figure
        """
        # Set theme based on fig_style
        if fig_style == 'dark':
            px.defaults.template = "plotly_dark"
        else:
            px.defaults.template = "plotly_white"

        # Diverging scale listed red -> blue, then reversed so low values are
        # blue and high values red.
        custom_color_scale = [
            (1, '#d73027'),  # Red
            (7 / 8, '#f46d43'),  # Red-Orange
            (6 / 8, '#fdae61'),  # Orange
            (5 / 8, '#fee090'),  # Light Orange
            (4 / 8, '#e0f3f8'),  # Light Blue
            (3 / 8, '#abd9e9'),  # Sky Blue
            (2 / 8, '#74add1'),  # Medium Blue
            (1 / 8, '#4575b4'),  # Dark Blue
            (0, '#313695')  # Deep Blue
        ]
        custom_color_scale.reverse()

        # if category data
        if not pd.api.types.is_numeric_dtype(space_coord_concat[color_by]):
            # Create the scatter plot
            fig = px.scatter(
                space_coord_concat,
                x='sx',
                y='sy',
                color=color_by,
                # symbol=annotation,
                title=title,
                color_discrete_map=color_map,
                hover_name=color_by,
                hover_data=hover_text_list,
                # color_continuous_scale=custom_color_scale,
                # range_color=[0, max(space_coord_concat[color_by])],
            )
        else:
            # Numeric branch: continuous color scale anchored at 0 so the
            # colorbar always starts at zero (e.g. -log10 p-values).
            fig = px.scatter(
                space_coord_concat,
                x='sx',
                y='sy',
                color=color_by,
                symbol=annotation,
                title=title,
                hover_name=color_by,
                hover_data=hover_text_list,
                color_continuous_scale=custom_color_scale,
                range_color=[0, space_coord_concat[color_by].max()],
            )

        # Update marker size if specified
        if point_size is not None:
            fig.update_traces(marker=dict(size=point_size, symbol='circle'))

        # Update layout for figure size
        fig.update_layout(
            autosize=False,
            width=width,
            height=height,
        )

        # Adjusting the legend - Updated position and marker size
        fig.update_layout(
            legend=dict(
                yanchor="middle",  # Anchor point for y
                y=0.5,  # Center vertically
                xanchor="left",  # Anchor point for x
                x=1.02,  # Position just outside the plot
                font=dict(
                    size=10,
                ),
                itemsizing='constant',  # Makes legend markers a constant size
                itemwidth=30,  # Adjust width of legend items
            )
        )

        # Update colorbar to be at the bottom and horizontal
        fig.update_layout(
            coloraxis_colorbar=dict(
                orientation='h',
                x=0.5,
                y=-0.0,
                xanchor='center',
                yanchor='top',
                len=0.75,
                title=dict(
                    text='-log10(p)' if color_by == 'logp' else color_by,
                    side='top'
                )
            )
        )

        # Remove gridlines, axis labels, and ticks
        fig.update_xaxes(
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            title=None,
            scaleanchor='y',  # lock x scale to y for an equal aspect ratio
        )

        fig.update_yaxes(
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            title=None,
            # Flip y when the data origin is the image's upper-left corner
            autorange='reversed' if self.config.plot_origin == 'upper' else True
        )

        # Adjust margins to ensure no clipping and equal axis ratio
        # NOTE(review): height=width below overrides the `height` argument.
        fig.update_layout(
            margin=dict(l=0, r=100, t=20, b=10),  # Increased right margin to accommodate legend
            height=width
        )

        # Adjust the title location and font size
        fig.update_layout(
            title=dict(
                y=0.98,
                x=0.5,
                xanchor='center',
                yanchor='top',
                font=dict(
                    size=20
                )
            ))

        return fig
888
+
889
+ @classmethod
890
+ def estimate_matplitlib_scatter_marker_size(cls, ax: matplotlib.axes.Axes, coordinates: np.ndarray,
891
+ x_limits: tuple | None = None,
892
+ y_limits: tuple | None = None) -> float:
893
+ """Alias for estimate_matplotlib_scatter_marker_size (with typo) for backward compatibility."""
894
+ return estimate_matplotlib_scatter_marker_size(ax, coordinates, x_limits, y_limits)
895
+
896
+ @classmethod
897
+ def estimate_matplotlib_scatter_marker_size(cls, ax: matplotlib.axes.Axes, coordinates: np.ndarray,
898
+ x_limits: tuple | None = None,
899
+ y_limits: tuple | None = None) -> float:
900
+ """Alias for estimate_matplotlib_scatter_marker_size for backward compatibility."""
901
+ return estimate_matplotlib_scatter_marker_size(ax, coordinates, x_limits, y_limits)
902
+
903
    def _create_multi_sample_annotation_plot(self, obs_ldsc_merged: pd.DataFrame, annotation: str,
                                             sample_names_list: list, output_dir: Path,
                                             n_rows: int, n_cols: int,
                                             fig_width: float = 20, fig_height: float = 15,
                                             scaling_factor: float = 1.0, dpi: int = 300):
        """Create multi-sample annotation plot using matplotlib with subplots for each sample.

        One scatter subplot per sample, colored by the ``annotation`` column:
        a shared continuous colormap/colorbar for numeric annotations, or a
        discrete color map plus a figure-level legend for categorical ones.

        Parameters
        ----------
        obs_ldsc_merged : pd.DataFrame
            Must contain 'sx', 'sy', 'sample_name' and the ``annotation`` column.
        annotation : str
            Column to color spots by.
        sample_names_list : list
            Samples to plot; entries beyond ``n_rows * n_cols`` are silently
            dropped by the slice below.
        output_dir : Path
            If truthy, the figure is saved there as ``multi_sample_<annotation>.png``.
        n_rows, n_cols : int
            Subplot grid shape.
        fig_width, fig_height : float
            Figure size in inches.
        scaling_factor : float
            Multiplier applied to the auto-estimated marker size.
        dpi : int
            Resolution for the saved PNG.

        Returns
        -------
        matplotlib.figure.Figure
        """

        print(f"Creating multi-sample plot for annotation: {annotation}")

        # Create figure
        fig = plt.figure(figsize=(fig_width, fig_height))
        fig.suptitle(f'{annotation} - All Samples', fontsize=24, fontweight='bold', y=0.98)

        # Create grid of subplots
        grid_specs = fig.add_gridspec(nrows=n_rows, ncols=n_cols, wspace=0.1, hspace=0.1)

        # Get unique annotation values and create color map.
        # Numeric branch: one colormap + normalization shared across subplots
        # so colors are comparable between samples.
        unique_annotations = obs_ldsc_merged[annotation].unique()
        if pd.api.types.is_numeric_dtype(obs_ldsc_merged[annotation]):

            custom_cmap = matplotlib.colors.LinearSegmentedColormap.from_list('custom_cmap', self.custom_colors_list)
            cmap = custom_cmap.reversed()
            norm = plt.Normalize(vmin=obs_ldsc_merged[annotation].min(),
                                 vmax=obs_ldsc_merged[annotation].max())
        else:
            # For categorical annotations, use discrete colors
            color_map = _create_color_map(unique_annotations, hex=False)

        # Create scatter plot for each sample
        for position_num, sample_name in enumerate(sample_names_list[:n_rows * n_cols], 1):
            # Calculate row and column in the grid
            row = (position_num - 1) // n_cols
            col = (position_num - 1) % n_cols

            # Create subplot
            ax = fig.add_subplot(grid_specs[row, col])

            # Get data for this sample
            sample_data = obs_ldsc_merged[obs_ldsc_merged['sample_name'] == sample_name]

            # Estimate point size based on data density
            point_size = estimate_matplotlib_scatter_marker_size(ax, sample_data[['sx', 'sy']].values)
            point_size *= scaling_factor  # Apply scaling factor

            # Create scatter plot
            if pd.api.types.is_numeric_dtype(obs_ldsc_merged[annotation]):
                ax.scatter(sample_data['sx'], sample_data['sy'],
                           c=sample_data[annotation], cmap=cmap, norm=norm,
                           s=point_size, alpha=1.0, edgecolors='none')
            else:
                # For categorical data, plot each category separately so each
                # gets a stable color from color_map
                for cat in unique_annotations:
                    cat_data = sample_data[sample_data[annotation] == cat]
                    if len(cat_data) > 0:
                        ax.scatter(cat_data['sx'], cat_data['sy'],
                                   c=[color_map[cat]], s=point_size, alpha=1.0,
                                   edgecolors='none', label=cat)

            # Set subplot title
            ax.set_title(sample_name, fontsize=10)
            ax.set_aspect('equal')
            if self.config.plot_origin == 'upper':
                # Image-style origin: flip y so row 0 is at the top
                ax.invert_yaxis()
            ax.axis('off')

        # Add colorbar for numeric annotations or legend for categorical
        if pd.api.types.is_numeric_dtype(obs_ldsc_merged[annotation]):
            # Create a colorbar on the right side of the figure
            sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
            sm.set_array([])
            fig.subplots_adjust(right=0.85)
            cbar_ax = fig.add_axes([0.88, 0.15, 0.03, 0.7])
            cbar = fig.colorbar(sm, cax=cbar_ax, orientation='vertical')
            cbar.set_label(annotation, fontsize=14)
        else:
            # Create a legend on the right side of the figure using proxy
            # artists (one marker per category)
            handles = [plt.Line2D([0], [0], marker='o', color='w', label=label,
                                  markerfacecolor=color, markersize=getattr(self.config, 'legend_marker_size', 10))
                       for label, color in color_map.items()]
            fig.subplots_adjust(right=0.8)
            fig.legend(handles=handles, title=annotation, loc='center left', bbox_to_anchor=(0.85, 0.5))

        # Save the plot if output_dir is provided
        # NOTE(review): plt.savefig saves the *current* figure; this assumes
        # `fig` is still current (no other figure created since) — confirm.
        if output_dir:
            output_path = output_dir / f'multi_sample_{annotation}.png'
            plt.savefig(output_path, dpi=dpi, bbox_inches='tight', facecolor='white')
            print(f"Saved multi-sample annotation plot: {output_path}")

        return fig
992
+
993
+ def _run_cauchy_analysis(self, obs_ldsc_merged: pd.DataFrame):
994
+ """Run Cauchy combination analysis"""
995
+
996
+ trait_names = self.config.trait_name_list
997
+
998
+ for annotation_col in self.config.cauchy_annotations:
999
+ print(f"Running Cauchy analysis for {annotation_col}...")
1000
+
1001
+ cauchy_results = self._run_cauchy_combination_per_annotation(
1002
+ obs_ldsc_merged, annotation_col=annotation_col, trait_cols=trait_names)
1003
+
1004
+ # Save results
1005
+ output_file = self.config.visualization_result_dir / f'cauchy_results_{annotation_col}.csv'
1006
+ self._save_cauchy_results_to_csv(cauchy_results, output_file)
1007
+
1008
+ # Generate heatmaps
1009
+ self._generate_cauchy_heatmaps(cauchy_results, annotation_col)
1010
+
1011
+ def _run_cauchy_combination_per_annotation(self, df: pd.DataFrame, annotation_col: str,
1012
+ trait_cols: list[str], max_workers=None):
1013
+ """
1014
+ Runs the Cauchy combination on each annotation category for each given trait in parallel.
1015
+ Also calculates odds ratios with confidence intervals for significant spots in each annotation.
1016
+ """
1017
+ from functools import partial
1018
+
1019
+ import statsmodels.api as sm
1020
+ from scipy.stats import fisher_exact
1021
+
1022
+ results_dict = {}
1023
+ annotations = df[annotation_col].unique()
1024
+
1025
+ # Helper function to process a single trait for a given annotation
1026
+ def process_trait(trait, anno_data, all_data, annotation):
1027
+ # Calculate significance threshold (Bonferroni correction)
1028
+ sig_threshold = 0.05 / len(all_data)
1029
+
1030
+ # Get p-values for this annotation and trait
1031
+ log10p = anno_data[trait].values
1032
+ log10p, mask = remove_outliers_MAD(log10p, )
1033
+ p_values = 10 ** (-log10p) # convert from log10(p) to p
1034
+
1035
+ # Calculate Cauchy combination and median
1036
+ p_cauchy_val = self._acat_test(p_values)
1037
+ p_median_val = np.median(p_values)
1038
+
1039
+ # Calculate significance statistics
1040
+ sig_spots_in_anno = np.sum(p_values < sig_threshold)
1041
+ total_spots_in_anno = len(p_values)
1042
+
1043
+ # Get p-values for other annotations
1044
+ other_annotations_mask = all_data[annotation_col] != annotation
1045
+ other_p_values = 10 ** (-all_data.loc[other_annotations_mask, trait].values)
1046
+ sig_spots_elsewhere = np.sum(other_p_values < sig_threshold)
1047
+ total_spots_elsewhere = len(other_p_values)
1048
+
1049
+ # Odds ratio calculation using Fisher's exact test
1050
+ try:
1051
+ # Create contingency table
1052
+ contingency_table = np.array([
1053
+ [sig_spots_in_anno, total_spots_in_anno - sig_spots_in_anno],
1054
+ [sig_spots_elsewhere, total_spots_elsewhere - sig_spots_elsewhere]
1055
+ ])
1056
+
1057
+ # Calculate odds ratio and p-value using Fisher's exact test
1058
+ odds_ratio, p_value = fisher_exact(contingency_table)
1059
+
1060
+ # if odds_ratio is infinite, set it to a large number
1061
+ if odds_ratio == np.inf:
1062
+ odds_ratio = 1e4 # Set to a large number to avoid overflow
1063
+
1064
+ # Calculate confidence intervals
1065
+ table = sm.stats.Table2x2(contingency_table)
1066
+ conf_int = table.oddsratio_confint()
1067
+ ci_low, ci_high = conf_int
1068
+ except Exception as e:
1069
+ # Handle calculation errors
1070
+ odds_ratio = 0
1071
+ p_value = 1
1072
+ ci_low, ci_high = 0, 0
1073
+ print(f"Fisher's exact test failed for {trait} in {annotation}: {e}")
1074
+
1075
+ return {
1076
+ 'trait': trait,
1077
+ 'p_cauchy': p_cauchy_val,
1078
+ 'p_median': p_median_val,
1079
+ 'odds_ratio': odds_ratio,
1080
+ 'ci_low': ci_low,
1081
+ 'ci_high': ci_high,
1082
+ 'p_odds_ratio': p_value,
1083
+ 'sig_spots': sig_spots_in_anno,
1084
+ 'total_spots': total_spots_in_anno,
1085
+ 'sig_ratio': sig_spots_in_anno / total_spots_in_anno if total_spots_in_anno > 0 else 0,
1086
+ 'overall_sig_spots': sig_spots_in_anno + sig_spots_elsewhere,
1087
+ 'overall_spots': total_spots_in_anno + total_spots_elsewhere
1088
+ }
1089
+
1090
+ # Process each annotation (sequential)
1091
+ for anno in tqdm(annotations, desc="Processing annotations"):
1092
+ df_anno = df[df[annotation_col] == anno]
1093
+
1094
+ # Create a partial function with fixed parameters
1095
+ process_trait_for_anno = partial(process_trait,
1096
+ anno_data=df_anno,
1097
+ all_data=df,
1098
+ annotation=anno)
1099
+
1100
+ # Process traits in parallel with progress bar
1101
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
1102
+ # Create list for results and submit all tasks
1103
+ futures = list(tqdm(
1104
+ executor.map(process_trait_for_anno, trait_cols),
1105
+ total=len(trait_cols),
1106
+ desc=f"Processing traits for {anno}",
1107
+ leave=False
1108
+ ))
1109
+ trait_results = list(futures)
1110
+
1111
+ # Create a DataFrame for this annotation
1112
+ anno_results_df = pd.DataFrame(trait_results).sort_values(by='p_cauchy')
1113
+ results_dict[anno] = anno_results_df
1114
+
1115
+ return results_dict
1116
+
1117
+ def _acat_test(self, pvalues: np.ndarray, weights=None):
1118
+ logger = logging.getLogger('gsMap.post_analysis.cauchy')
1119
+ if np.any(np.isnan(pvalues)):
1120
+ raise ValueError("Cannot have NAs in the p-values.")
1121
+ if np.any((pvalues > 1) | (pvalues < 0)):
1122
+ raise ValueError("P-values must be between 0 and 1.")
1123
+ if np.any(pvalues == 0) and np.any(pvalues == 1):
1124
+ raise ValueError("Cannot have both 0 and 1 p-values.")
1125
+ if np.any(pvalues == 0):
1126
+ logger.info("Warn: p-values are exactly 0.")
1127
+ return 0
1128
+ if np.any(pvalues == 1):
1129
+ logger.info("Warn: p-values are exactly 1.")
1130
+ return 1
1131
+
1132
+ if weights is None:
1133
+ weights = np.full(len(pvalues), 1 / len(pvalues))
1134
+ else:
1135
+ if len(weights) != len(pvalues):
1136
+ raise Exception("Length of weights and p-values differs.")
1137
+ if any(weights < 0):
1138
+ raise Exception("All weights must be positive.")
1139
+ weights = np.array(weights) / np.sum(weights)
1140
+
1141
+ is_small = pvalues < 1e-16
1142
+ is_large = ~is_small
1143
+
1144
+ if not np.any(is_small):
1145
+ cct_stat = np.sum(weights * np.tan((0.5 - pvalues) * np.pi))
1146
+ else:
1147
+ cct_stat = np.sum((weights[is_small] / pvalues[is_small]) / np.pi) + \
1148
+ np.sum(weights[is_large] * np.tan((0.5 - pvalues[is_large]) * np.pi))
1149
+
1150
+ if cct_stat > 1e15:
1151
+ pval = (1 / cct_stat) / np.pi
1152
+ else:
1153
+ pval = 1 - stats.cauchy.cdf(cct_stat)
1154
+
1155
+ return pval
1156
+
1157
+ def _save_cauchy_results_to_csv(self, cauchy_results: dict, output_path: Path):
1158
+ """Save Cauchy results to CSV"""
1159
+ all_results = []
1160
+ for annotation, df in cauchy_results.items():
1161
+ df_copy = df.copy()
1162
+ df_copy['annotation'] = annotation
1163
+ all_results.append(df_copy)
1164
+
1165
+ combined_results = pd.concat(all_results, ignore_index=True)
1166
+ output_path.parent.mkdir(parents=True, exist_ok=True)
1167
+ combined_results.to_csv(output_path, index=False)
1168
+
1169
+ return combined_results
1170
+
1171
+ def _generate_cauchy_heatmaps(self, cauchy_results: dict, annotation_col: str):
1172
+ """Generate multiple types of Cauchy combination heatmaps"""
1173
+ # Convert results to pivot table format for different metrics
1174
+ table_cauchy = self._results_dict_to_log10_table(cauchy_results, value_col='p_cauchy', log10_transform=True)
1175
+ table_median = self._results_dict_to_log10_table(cauchy_results, value_col='p_median', log10_transform=True)
1176
+ table_odds_ratio = self._results_dict_to_log10_table(cauchy_results, value_col='odds_ratio',
1177
+ log10_transform=False)
1178
+
1179
+ # Create heatmap directories
1180
+ cauchy_heatmap_base = self.config.visualization_result_dir / 'cauchy_heatmap'
1181
+ static_folder = cauchy_heatmap_base / 'static_png'
1182
+ interactive_folder = cauchy_heatmap_base / 'interactive_html'
1183
+
1184
+ for folder in [static_folder, interactive_folder]:
1185
+ folder.mkdir(exist_ok=True, parents=True)
1186
+
1187
+ # Calculate dimensions
1188
+ num_annotations, num_traits = table_cauchy.shape
1189
+ width = 50 * num_traits
1190
+ height = 50 * num_annotations
1191
+
1192
+ # 1. Cauchy combination heatmap (non-normalized)
1193
+ fig = self._plot_p_cauchy_heatmap(
1194
+ df=table_cauchy,
1195
+ title=f"Cauchy Combination Heatmap -- By {annotation_col}",
1196
+ cluster_rows=True, cluster_cols=True,
1197
+ width=width, height=height,
1198
+ text_format=".2f", font_size=10, margin_pad=150
1199
+ )
1200
+ fig.write_image(static_folder / f'cauchy_combination_by_{annotation_col}.png', scale=2)
1201
+ fig.write_html(interactive_folder / f'cauchy_combination_by_{annotation_col}.html')
1202
+
1203
+ # 2. Cauchy combination heatmap (normalized)
1204
+ fig = self._plot_p_cauchy_heatmap(
1205
+ df=table_cauchy,
1206
+ title=f"Cauchy Combination Heatmap -- By {annotation_col}",
1207
+ normalize_axis='column',
1208
+ cluster_rows=True, cluster_cols=True,
1209
+ width=width, height=height,
1210
+ text_format=".2f", font_size=10, margin_pad=150
1211
+ )
1212
+ fig.write_image(static_folder / f'cauchy_combination_by_{annotation_col}_normalized.png', scale=2)
1213
+ fig.write_html(interactive_folder / f'cauchy_combination_by_{annotation_col}_normalized.html')
1214
+
1215
+ # 3. Median p-value heatmap (non-normalized)
1216
+ fig = self._plot_p_cauchy_heatmap(
1217
+ df=table_median,
1218
+ title=f"Median log 10 pvalue Heatmap -- By {annotation_col}",
1219
+ cluster_rows=True, cluster_cols=True,
1220
+ width=width, height=height,
1221
+ text_format=".2f", font_size=10, margin_pad=150
1222
+ )
1223
+ fig.write_image(static_folder / f'median_pvalue_{annotation_col}.png', scale=2)
1224
+ fig.write_html(interactive_folder / f'median_pvalue_{annotation_col}.html')
1225
+
1226
+ # 4. Median p-value heatmap (normalized)
1227
+ fig = self._plot_p_cauchy_heatmap(
1228
+ df=table_median,
1229
+ title=f"Median log 10 pvalue Heatmap -- By {annotation_col}",
1230
+ normalize_axis='column',
1231
+ cluster_rows=True, cluster_cols=True,
1232
+ width=width, height=height,
1233
+ text_format=".2f", font_size=10, margin_pad=150
1234
+ )
1235
+ fig.write_image(static_folder / f'median_pvalue_{annotation_col}_normalized.png', scale=2)
1236
+ fig.write_html(interactive_folder / f'median_pvalue_{annotation_col}_normalized.html')
1237
+
1238
+ # 5. Odds ratio heatmap
1239
+ fig = self._plot_p_cauchy_heatmap(
1240
+ df=table_odds_ratio,
1241
+ title=f"Odds Ratio Heatmap -- By {annotation_col}",
1242
+ cluster_rows=True, cluster_cols=True,
1243
+ width=width, height=height,
1244
+ text_format=".2f", font_size=10, margin_pad=150
1245
+ )
1246
+ fig.write_image(static_folder / f'odds_ratio_{annotation_col}.png', scale=2)
1247
+ fig.write_html(interactive_folder / f'odds_ratio_{annotation_col}.html')
1248
+
1249
+ def _results_dict_to_log10_table(self, results_dict: dict, value_col: str = 'p_cauchy',
1250
+ log10_transform: bool = True, epsilon: float = 1e-300) -> pd.DataFrame:
1251
+ """Convert results dict to pivot table"""
1252
+ all_data = []
1253
+ for anno, df in results_dict.items():
1254
+ temp = df.copy()
1255
+ temp['annotation'] = anno
1256
+ all_data.append(temp)
1257
+
1258
+ combined_df = pd.concat(all_data, ignore_index=True)
1259
+
1260
+ if log10_transform:
1261
+ combined_df.loc[combined_df[value_col] == 0, value_col] = epsilon
1262
+ combined_df['transformed'] = -np.log10(combined_df[value_col])
1263
+ else:
1264
+ combined_df['transformed'] = combined_df[value_col]
1265
+
1266
+ pivot_df = combined_df.pivot(index='annotation', columns='trait', values='transformed')
1267
+ return pivot_df
1268
+
1269
+ def _plot_p_cauchy_heatmap(self, df: pd.DataFrame, title: str = "Cauchy Combination Heatmap",
1270
+ normalize_axis: Literal["row", "column"] | None = None,
1271
+ cluster_rows: bool = False, cluster_cols: bool = False,
1272
+ color_continuous_scale: str | list = "RdBu_r",
1273
+ width: int | None = None, height: int | None = None,
1274
+ text_format: str = ".2f",
1275
+ show_text: bool = True, font_size: int = 10, margin_pad: int = 150) -> go.Figure:
1276
+ """
1277
+ Create an enhanced heatmap visualization for trait-annotation relationships.
1278
+ """
1279
+ data = df.copy()
1280
+
1281
+ # Input validation
1282
+ if not isinstance(data, pd.DataFrame):
1283
+ raise TypeError("Input must be a pandas DataFrame")
1284
+ if data.empty:
1285
+ raise ValueError("Input DataFrame is empty")
1286
+ if not np.issubdtype(data.values.dtype, np.number):
1287
+ raise ValueError("DataFrame must contain numeric values")
1288
+
1289
+ n_rows, n_cols = data.shape
1290
+ # Set dynamic width/height if not provided to ensure good aspect ratio
1291
+ # Previously we used 50 and 30, which led to vertical stretching.
1292
+ # Let's use more balanced units.
1293
+ if width is None:
1294
+ width = max(600, n_cols * 150 + margin_pad * 2)
1295
+ if height is None:
1296
+ height = max(500, n_rows * 60 + margin_pad * 2)
1297
+
1298
+ # Normalization with error handling
1299
+ if normalize_axis in ['row', 'column']:
1300
+ axis = 1 if normalize_axis == 'row' else 0
1301
+ try:
1302
+ # Store original data for text annotations
1303
+ original_data = data.copy()
1304
+
1305
+ # Calculate min and max along specified axis
1306
+ min_vals = data.min(axis=axis)
1307
+ max_vals = data.max(axis=axis)
1308
+ range_vals = max_vals - min_vals
1309
+
1310
+ # Replace zero range with 1 to avoid division by zero
1311
+ range_vals = range_vals.replace(0, 1)
1312
+
1313
+ # Normalize using broadcasting
1314
+ if normalize_axis == 'row':
1315
+ data = data.sub(min_vals, axis=0).div(range_vals, axis=0)
1316
+ else: # column
1317
+ data = data.sub(min_vals, axis=1).div(range_vals, axis=1)
1318
+
1319
+ data = data.fillna(0)
1320
+ except Exception as e:
1321
+ raise ValueError(f"Normalization failed: {str(e)}")
1322
+ else:
1323
+ # No normalization, use original data for both color and text
1324
+ original_data = data
1325
+
1326
+ # Clustering with error handling
1327
+ try:
1328
+ if cluster_rows:
1329
+ row_linkage = linkage(data.fillna(0).values, method='average', metric='euclidean')
1330
+ row_order = leaves_list(row_linkage)
1331
+ data = data.iloc[row_order, :]
1332
+ original_data = original_data.iloc[row_order, :] # Apply the same order to original data
1333
+
1334
+ if cluster_cols:
1335
+ col_linkage = linkage(data.fillna(0).values.T, method='average', metric='euclidean')
1336
+ col_order = leaves_list(col_linkage)
1337
+ data = data.iloc[:, col_order]
1338
+ original_data = original_data.iloc[:, col_order] # Apply the same order to original data
1339
+ except Exception as e:
1340
+ raise ValueError(f"Clustering failed: {str(e)}")
1341
+
1342
+ # Create heatmap with enhanced formatting
1343
+ if normalize_axis is None:
1344
+ # Use original settings for speed when no normalization is applied
1345
+ fig = px.imshow(
1346
+ data,
1347
+ color_continuous_scale=color_continuous_scale,
1348
+ aspect='auto',
1349
+ width=width,
1350
+ height=height,
1351
+ text_auto=text_format if show_text else False # Automatic text generation
1352
+ )
1353
+ else:
1354
+ # Use custom logic for normalization (manual text annotations)
1355
+ fig = px.imshow(
1356
+ data,
1357
+ color_continuous_scale=color_continuous_scale,
1358
+ aspect='auto',
1359
+ width=width,
1360
+ height=height,
1361
+ text_auto=False # Disable automatic text generation
1362
+ )
1363
+
1364
+ # Add manual text annotations using original data
1365
+ if show_text:
1366
+ for i, row in enumerate(original_data.values):
1367
+ for j, value in enumerate(row):
1368
+ fig.add_annotation(
1369
+ x=j,
1370
+ y=i,
1371
+ text=f"{value:{text_format}}",
1372
+ showarrow=False,
1373
+ font=dict(size=font_size, color='black')
1374
+ )
1375
+
1376
+ # Enhanced layout configuration
1377
+ fig.update_layout(
1378
+ title={
1379
+ 'text': title,
1380
+ 'y': 0.98,
1381
+ 'x': 0.5,
1382
+ 'xanchor': 'center',
1383
+ 'yanchor': 'bottom',
1384
+ 'font': {'size': font_size + 4}
1385
+ },
1386
+ xaxis={
1387
+ 'title': "Trait",
1388
+ 'tickangle': 45,
1389
+ 'side': 'bottom',
1390
+ 'tickfont': {'size': font_size},
1391
+ 'title_font': {'size': font_size + 2}
1392
+ },
1393
+ yaxis={
1394
+ 'title': "Annotation",
1395
+ 'tickfont': {'size': font_size},
1396
+ 'title_font': {'size': font_size + 2}
1397
+ },
1398
+ width=width,
1399
+ height=height,
1400
+ template='plotly_white',
1401
+ margin=dict(l=margin_pad, r=margin_pad, t=margin_pad, b=margin_pad),
1402
+ coloraxis_colorbar={
1403
+ 'title': '-log10(p)' if not normalize_axis else 'Normalized Value',
1404
+ 'title_side': 'right',
1405
+ 'title_font': {'size': font_size + 2},
1406
+ 'tickfont': {'size': font_size}
1407
+ }
1408
+ )
1409
+ return fig