gsMap 1.71.1__py3-none-any.whl → 1.72.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,154 @@
1
+ import logging
2
+ from pathlib import Path
3
+
4
+ import anndata
5
+ import numpy as np
6
+ import pandas as pd
7
+ import scanpy as sc
8
+ import zarr
9
+ from scipy.stats import rankdata
10
+ from tqdm import tqdm
11
+
12
+ from gsMap.config import CreateSliceMeanConfig
13
+
14
+ # %% Helper functions
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
def get_common_genes(h5ad_files, config: CreateSliceMeanConfig):
    """
    Return the sorted list of gene names shared by every h5ad file.

    Each file is read, its variable names are made unique, and the running
    intersection of variable names across files is kept. When
    ``config.species`` is set, the result is further restricted to the genes
    listed in the first column of the tab-separated homolog table found at
    ``config.homolog_file``.

    Args:
        h5ad_files: Paths of the h5ad files to scan.
        config: Supplies ``species`` and, when a species is set,
            ``homolog_file``.

    Returns:
        A sorted list of the gene names common to all files.

    Raises:
        ValueError: If ``h5ad_files`` is empty, or if the homolog table has
            fewer than two columns.
    """
    common_genes = None
    for file in tqdm(h5ad_files, desc="Finding common genes"):
        adata = sc.read_h5ad(file)
        adata.var_names_make_unique()
        if common_genes is None:
            common_genes = adata.var_names
        else:
            common_genes = common_genes.intersection(adata.var_names)

    # Without at least one file there is nothing to intersect; fail with a
    # clear message instead of an opaque TypeError further down.
    if common_genes is None:
        raise ValueError("No h5ad files were provided; cannot determine common genes.")

    if config.species is not None:
        homologs = pd.read_csv(config.homolog_file, sep="\t")
        if homologs.shape[1] < 2:
            raise ValueError(
                "Homologs file must have at least two columns: one for the species and one for the human gene symbol."
            )
        # Only the first two columns are used. Dropping any extra columns
        # keeps the two-name rename below from failing with a length mismatch.
        homologs = homologs.iloc[:, :2]
        homologs.columns = [config.species, "HUMAN_GENE_SYM"]
        homologs.set_index(config.species, inplace=True)
        common_genes = np.intersect1d(common_genes, homologs.index)

    common_genes = sorted(common_genes)
    return common_genes
44
+
45
+
46
def calculate_one_slice_mean(
    sample_name, file_path: Path, common_genes, zarr_group_path, data_layer
):
    """
    Compute per-gene rank statistics for one slice and store them in a Zarr group.

    For every spot (row), the expression values of ``common_genes`` are ranked
    with average ties. The geometric mean of those ranks is then taken per gene
    across all spots using the log trick. The number of spots with non-zero
    expression is also counted per gene. Both vectors are written as one
    ``(n_genes, 2)`` float32 array named ``sample_name``; the slice's spot
    count is stored in that array's ``spot_number`` attribute.

    Note that the second column holds a *count* of expressing spots, not a
    fraction; ``merge_zarr_means`` divides the summed counts by the total
    number of spots.

    Args:
        sample_name: Name of the array created inside the Zarr group.
        file_path: Path to the slice's h5ad file.
        common_genes: Gene names used to subset (and order) the variables.
        zarr_group_path: Location of the Zarr group, opened in append mode.
        data_layer: Name of the layer to use as the expression matrix, or
            ``"X"`` to use ``adata.X`` as stored.

    Raises:
        ValueError: If ``data_layer`` is neither an existing layer nor
            ``"X"``, or if the slice contains no spots.
    """
    gmean_zarr_group = zarr.open(zarr_group_path, mode="a")
    adata = anndata.read_h5ad(file_path)

    if data_layer in adata.layers:
        adata.X = adata.layers[data_layer]
    elif data_layer != "X":
        raise ValueError(f"Data layer {data_layer} not found in {file_path}")

    adata = adata[:, common_genes].copy()
    n_cells = adata.shape[0]
    if n_cells == 0:
        # An empty slice would store NaN means with weight 0, which turns
        # every merged value into NaN later on.
        raise ValueError(f"No spots found in {file_path}; cannot compute slice mean.")

    # Accumulate the per-gene sum of log ranks instead of materialising an
    # (n_cells, n_genes) matrix; float64 keeps the running sum accurate.
    log_rank_sum = np.zeros(adata.n_vars, dtype=np.float64)
    for i in tqdm(range(n_cells), desc=f"Computing log ranks for {sample_name}"):
        row = adata.X[i, :]
        if hasattr(row, "toarray"):
            # Sparse row: densify before ranking. Dense matrices skip this.
            row = row.toarray()
        ranks = rankdata(np.asarray(row).ravel(), method="average")
        # Ranks are always >= 1, so the logarithm is well defined.
        log_rank_sum += np.log(ranks)

    # Geometric mean via the log trick: exp(mean(log(values)))
    gmean = np.exp(log_rank_sum / n_cells).reshape(-1, 1)

    # Count, per gene, the spots with non-zero expression
    adata_X_bool = adata.X.astype(bool)
    frac = (np.asarray(adata_X_bool.sum(axis=0)).flatten()).reshape(-1, 1)

    # Save to zarr group
    gmean_frac = np.concatenate([gmean, frac], axis=1)
    s1_zarr = gmean_zarr_group.array(sample_name, data=gmean_frac, chunks=None, dtype="f4")
    s1_zarr.attrs["spot_number"] = adata.shape[0]
83
+
84
+
85
def merge_zarr_means(zarr_group_path, output_file, common_genes):
    """
    Merge the per-slice arrays of a Zarr group and save the result as Parquet.

    Every array in the group is expected to have one row per gene in
    ``common_genes`` and two columns: the slice's geometric mean and its count
    of expressing spots. Each array also carries a ``spot_number`` attribute.
    The merged geometric mean is the spot-weighted mean of the per-slice logs,
    exponentiated. The merged fraction is the summed counts divided by the
    total number of spots.

    Args:
        zarr_group_path: Location of the Zarr group holding the slice arrays.
        output_file: Destination of the Parquet file.
        common_genes: Gene names, in the same order as the array rows.

    Returns:
        A DataFrame indexed by ``gene`` with columns ``G_Mean`` and ``frac``.

    Raises:
        ValueError: If an array's shape does not match ``common_genes``, or
            if the group holds no spots at all.
    """
    gmean_zarr_group = zarr.open(zarr_group_path, mode="a")
    n_genes = len(common_genes)

    # Accumulate in float64: float32 loses integer precision for counts above
    # 2**24 and degrades the weighted log sum when there are many spots.
    log_sum = np.zeros(n_genes, dtype=np.float64)
    frac_sum = np.zeros(n_genes, dtype=np.float64)
    total_spot_number = 0
    for key in tqdm(gmean_zarr_group.array_keys(), desc="Merging Zarr arrays"):
        s1 = gmean_zarr_group[key]
        s1_values = s1[:]  # read the array from the store once
        if s1_values.shape != (n_genes, 2):
            # Possibly an array left over from a run with a different gene set.
            raise ValueError(
                f"Zarr array {key} has shape {s1_values.shape}, expected {(n_genes, 2)}."
            )
        n = s1.attrs["spot_number"]

        log_sum += np.log(s1_values[:, 0].astype(np.float64)) * n
        frac_sum += s1_values[:, 1]
        total_spot_number += n

    if total_spot_number == 0:
        raise ValueError(f"No slice means with spots found in {zarr_group_path}.")

    # Geometric mean via exponentiation of the averaged logs; cast back to
    # float32 so the stored columns keep the dtype of the per-slice arrays.
    final_mean = np.exp(log_sum / total_spot_number).astype(np.float32)
    final_frac = (frac_sum / total_spot_number).astype(np.float32)

    # Save the final mean to a Parquet file
    final_df = pd.DataFrame({"gene": common_genes, "G_Mean": final_mean, "frac": final_frac})
    final_df.set_index("gene", inplace=True)
    final_df.to_parquet(output_file)
    return final_df
119
+
120
+
121
def run_create_slice_mean(config: CreateSliceMeanConfig):
    """
    Main entrypoint to create slice means.

    Reads the sample-name -> h5ad-file mapping from ``config.h5ad_dict``,
    determines the genes common to all files, computes the per-slice
    statistics into a Zarr group stored next to
    ``config.slice_mean_output_file`` (same path with a ``.zarr`` suffix), and
    finally merges them into the output file. Samples whose array already
    exists in the Zarr group are skipped.

    Returns the merged DataFrame produced by ``merge_zarr_means``.
    """
    sample_to_file = config.h5ad_dict

    # Genes shared by every input file
    common_genes = get_common_genes(list(sample_to_file.values()), config)
    logger.info(f"Found {len(common_genes)} common genes across all files.")

    # The Zarr group used as a per-sample cache lives beside the output file
    zarr_group_path = config.slice_mean_output_file.with_suffix(".zarr")

    for sample_name, h5ad_file in sample_to_file.items():
        # A sample counts as processed when its array is already in the group
        already_processed = False
        if zarr_group_path.exists():
            existing_group = zarr.open(zarr_group_path.as_posix(), mode="r")
            already_processed = sample_name in existing_group.array_keys()

        if already_processed:
            logger.info(f"Skipping {sample_name}, already processed.")
        else:
            calculate_one_slice_mean(
                sample_name, h5ad_file, common_genes, zarr_group_path, config.data_layer
            )

    output_file = config.slice_mean_output_file
    merged_df = merge_zarr_means(zarr_group_path, output_file, common_genes)

    logger.info(f"Final slice mean and expression fraction saved to {output_file}")
    return merged_df