gsMap 1.71.2__py3-none-any.whl → 1.72.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,14 +1,15 @@
 import numpy as np
 import pandas as pd
 import scipy.sparse as sp
-from sklearn.neighbors import NearestNeighbors
 import torch
+from sklearn.neighbors import NearestNeighbors
+
 
 def cal_spatial_net(adata, n_neighbors=5, verbose=True):
     """Construct the spatial neighbor network."""
     if verbose:
-        print('------Calculating spatial graph...')
-    coor = pd.DataFrame(adata.obsm['spatial'], index=adata.obs.index)
+        print("------Calculating spatial graph...")
+    coor = pd.DataFrame(adata.obsm["spatial"], index=adata.obs.index)
     nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(coor)
     distances, indices = nbrs.kneighbors(coor)
     n_cells, n_neighbors = indices.shape
@@ -16,22 +17,22 @@ def cal_spatial_net(adata, n_neighbors=5, verbose=True):
     cell1 = np.repeat(cell_indices, n_neighbors)
     cell2 = indices.flatten()
     distance = distances.flatten()
-    knn_df = pd.DataFrame({'Cell1': cell1, 'Cell2': cell2, 'Distance': distance})
-    knn_df = knn_df[knn_df['Distance'] > 0].copy()
-    cell_id_map = dict(zip(cell_indices, coor.index))
-    knn_df['Cell1'] = knn_df['Cell1'].map(cell_id_map)
-    knn_df['Cell2'] = knn_df['Cell2'].map(cell_id_map)
+    knn_df = pd.DataFrame({"Cell1": cell1, "Cell2": cell2, "Distance": distance})
+    knn_df = knn_df[knn_df["Distance"] > 0].copy()
+    cell_id_map = dict(zip(cell_indices, coor.index, strict=False))
+    knn_df["Cell1"] = knn_df["Cell1"].map(cell_id_map)
+    knn_df["Cell2"] = knn_df["Cell2"].map(cell_id_map)
     return knn_df
 
+
 def sparse_mx_to_torch_sparse_tensor(sparse_mx):
     """Convert a scipy sparse matrix to a torch sparse tensor."""
     sparse_mx = sparse_mx.tocoo().astype(np.float32)
-    indices = torch.from_numpy(
-        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)
-    )
+    indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
     values = torch.from_numpy(sparse_mx.data)
     shape = torch.Size(sparse_mx.shape)
-    return torch.sparse_coo_tensor(indices, values, shape)
+    return torch.sparse_coo_tensor(indices, values, shape)
+
 
 def preprocess_graph(adj):
     """Symmetrically normalize the adjacency matrix."""
@@ -42,34 +43,31 @@ def preprocess_graph(adj):
     adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo()
     return sparse_mx_to_torch_sparse_tensor(adj_normalized)
 
+
 def construct_adjacency_matrix(adata, params, verbose=True):
     """Construct the adjacency matrix from spatial data."""
     spatial_net = cal_spatial_net(adata, n_neighbors=params.n_neighbors, verbose=verbose)
     if verbose:
         num_edges = spatial_net.shape[0]
         num_cells = adata.n_obs
-        print(f'The graph contains {num_edges} edges, {num_cells} cells.')
-        print(f'{num_edges / num_cells:.2f} neighbors per cell on average.')
+        print(f"The graph contains {num_edges} edges, {num_cells} cells.")
+        print(f"{num_edges / num_cells:.2f} neighbors per cell on average.")
     cell_ids = {cell: idx for idx, cell in enumerate(adata.obs.index)}
-    spatial_net['Cell1'] = spatial_net['Cell1'].map(cell_ids)
-    spatial_net['Cell2'] = spatial_net['Cell2'].map(cell_ids)
+    spatial_net["Cell1"] = spatial_net["Cell1"].map(cell_ids)
+    spatial_net["Cell2"] = spatial_net["Cell2"].map(cell_ids)
     if params.weighted_adj:
-        distance_normalized = spatial_net['Distance'] / (spatial_net['Distance'].max() + 1)
-        weights = np.exp(-0.5 * distance_normalized ** 2)
+        distance_normalized = spatial_net["Distance"] / (spatial_net["Distance"].max() + 1)
+        weights = np.exp(-0.5 * distance_normalized**2)
         adj_org = sp.coo_matrix(
-            (weights, (spatial_net['Cell1'], spatial_net['Cell2'])),
-            shape=(adata.n_obs, adata.n_obs)
+            (weights, (spatial_net["Cell1"], spatial_net["Cell2"])),
+            shape=(adata.n_obs, adata.n_obs),
         )
     else:
         adj_org = sp.coo_matrix(
-            (np.ones(spatial_net.shape[0]), (spatial_net['Cell1'], spatial_net['Cell2'])),
-            shape=(adata.n_obs, adata.n_obs)
+            (np.ones(spatial_net.shape[0]), (spatial_net["Cell1"], spatial_net["Cell2"])),
+            shape=(adata.n_obs, adata.n_obs),
        )
     adj_norm = preprocess_graph(adj_org)
     norm_value = adj_org.shape[0] ** 2 / ((adj_org.shape[0] ** 2 - adj_org.sum()) * 2)
-    graph_dict = {
-        "adj_org": adj_org,
-        "adj_norm": adj_norm,
-        "norm_value": norm_value
-    }
+    graph_dict = {"adj_org": adj_org, "adj_norm": adj_norm, "norm_value": norm_value}
     return graph_dict
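
Note on `preprocess_graph`: it applies the standard GCN-style symmetric normalization, rescaling the adjacency matrix as D^(-1/2) · A · D^(-1/2) so that high-degree spots do not dominate message passing. A minimal dense-NumPy sketch of that operation (illustrative only: the package code works on scipy sparse matrices, and the `adj_` it normalizes may include self-loops, which this toy version omits):

```python
import numpy as np

# Toy 3-spot adjacency; gsMap builds the real one from the spatial KNN graph.
A = np.array([[0.0, 1.0, 1.0],
              [1.0, 0.0, 0.0],
              [1.0, 0.0, 0.0]])

deg = A.sum(axis=1)                       # spot degrees
D_inv_sqrt = np.diag(1.0 / np.sqrt(deg))  # D^(-1/2)
A_norm = D_inv_sqrt @ A @ D_inv_sqrt      # symmetric normalization
print(A_norm)
```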
gsMap/GNN/model.py CHANGED
@@ -3,14 +3,16 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch_geometric.nn import GATConv
 
+
 def full_block(in_features, out_features, p_drop):
     return nn.Sequential(
         nn.Linear(in_features, out_features),
         nn.BatchNorm1d(out_features),
         nn.ELU(),
-        nn.Dropout(p=p_drop)
+        nn.Dropout(p=p_drop),
     )
 
+
 class GATModel(nn.Module):
     def __init__(self, input_dim, params, num_classes=1):
         super().__init__()
@@ -21,7 +23,7 @@ class GATModel(nn.Module):
         # Encoder
         self.encoder = nn.Sequential(
             full_block(input_dim, params.feat_hidden1, params.p_drop),
-            full_block(params.feat_hidden1, params.feat_hidden2, params.p_drop)
+            full_block(params.feat_hidden1, params.feat_hidden2, params.p_drop),
         )
 
         # GAT Layers
@@ -29,14 +31,14 @@
             in_channels=params.feat_hidden2,
             out_channels=params.gat_hidden1,
             heads=params.nheads,
-            dropout=params.p_drop
+            dropout=params.p_drop,
         )
         self.gat2 = GATConv(
             in_channels=params.gat_hidden1 * params.nheads,
             out_channels=params.gat_hidden2,
             heads=1,
             concat=False,
-            dropout=params.p_drop
+            dropout=params.p_drop,
         )
         if self.var:
             self.gat3 = GATConv(
@@ -44,20 +46,20 @@
                 out_channels=params.gat_hidden2,
                 heads=1,
                 concat=False,
-                dropout=params.p_drop
+                dropout=params.p_drop,
             )
 
         # Decoder
         self.decoder = nn.Sequential(
             full_block(params.gat_hidden2, params.feat_hidden2, params.p_drop),
             full_block(params.feat_hidden2, params.feat_hidden1, params.p_drop),
-            nn.Linear(params.feat_hidden1, input_dim)
+            nn.Linear(params.feat_hidden1, input_dim),
         )
 
         # Clustering Layer
         self.cluster = nn.Sequential(
             full_block(params.gat_hidden2, params.feat_hidden2, params.p_drop),
-            nn.Linear(params.feat_hidden2, self.num_classes)
+            nn.Linear(params.feat_hidden2, self.num_classes),
         )
 
     def encode(self, x, edge_index):
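
The model.py changes are purely cosmetic (trailing commas added by the new formatter); the architecture is unchanged: an MLP encoder feeds `gat1`, whose `nheads` attention heads are concatenated, so `gat2` must take `gat_hidden1 * nheads` input channels. A small shape-check sketch, with hypothetical hyperparameter values standing in for `params`:

```python
import torch
from torch_geometric.nn import GATConv

# Hypothetical values standing in for params.feat_hidden2, params.gat_hidden1,
# params.gat_hidden2, and params.nheads.
feat_hidden2, gat_hidden1, gat_hidden2, nheads = 64, 32, 16, 4

gat1 = GATConv(feat_hidden2, gat_hidden1, heads=nheads, dropout=0.1)
gat2 = GATConv(gat_hidden1 * nheads, gat_hidden2, heads=1, concat=False, dropout=0.1)

x = torch.randn(10, feat_hidden2)                  # 10 nodes with encoded features
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # toy edge list (COO format)
h = gat1(x, edge_index)                            # -> torch.Size([10, 128]): 32 channels * 4 heads
z = gat2(h, edge_index)                            # -> torch.Size([10, 16])
print(h.shape, z.shape)
```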
gsMap/GNN/train.py CHANGED
@@ -23,7 +23,7 @@ def label_loss(pred_label, true_label):
 class ModelTrainer:
     def __init__(self, node_x, graph_dict, params, label=None):
         """Initialize the ModelTrainer with data and hyperparameters."""
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.params = params
         self.epochs = params.epochs
         self.node_x = torch.FloatTensor(node_x).to(self.device)
@@ -38,17 +38,15 @@ class ModelTrainer:
         # Set up the model
         self.model = GATModel(self.params.feat_cell, self.params, self.num_classes).to(self.device)
         self.optimizer = torch.optim.Adam(
-            self.model.parameters(),
-            lr=self.params.gat_lr,
-            weight_decay=self.params.gcn_decay
+            self.model.parameters(), lr=self.params.gat_lr, weight_decay=self.params.gcn_decay
         )
 
     def run_train(self):
         """Train the model."""
         self.model.train()
-        prev_loss = float('inf')
-        logger.info('Start training...')
-        pbar = tqdm(range(self.epochs), desc='GAT-AE model train:', total=self.epochs)
+        prev_loss = float("inf")
+        logger.info("Start training...")
+        pbar = tqdm(range(self.epochs), desc="GAT-AE model train:", total=self.epochs)
         for epoch in range(self.epochs):
             start_time = time.time()
             self.optimizer.zero_grad()
@@ -67,18 +65,17 @@
             batch_time = time.time() - start_time
             left_time = batch_time * (self.epochs - epoch - 1) / 60  # in minutes
 
-            pbar.set_postfix({'Left time': f'{left_time:.2f} mins', 'Loss': f'{loss.item():.4f}'})
+            pbar.set_postfix({"Left time": f"{left_time:.2f} mins", "Loss": f"{loss.item():.4f}"})
             pbar.update(1)
 
             if abs(loss.item() - prev_loss) <= self.params.convergence_threshold and epoch >= 200:
                 pbar.close()
-                logger.info('Convergence reached. Training stopped.')
+                logger.info("Convergence reached. Training stopped.")
                 break
             prev_loss = loss.item()
         else:
             pbar.close()
-            logger.info('Max epochs reached. Training stopped.')
-
+            logger.info("Max epochs reached. Training stopped.")
 
     def get_latent(self):
         """Retrieve the latent representation from the model."""
gsMap/__init__.py CHANGED
@@ -1,5 +1,5 @@
-'''
+"""
 Genetics-informed pathogenic spatial mapping
-'''
+"""
 
-__version__ = '1.71.2'
+__version__ = "1.72.3"
gsMap/__main__.py CHANGED
@@ -1,3 +1,4 @@
 from .main import main
-if __name__ == '__main__':
-    main()
+
+if __name__ == "__main__":
+    main()
@@ -10,9 +10,10 @@ from gsMap.config import CauchyCombinationConfig
 
 logger = logging.getLogger(__name__)
 
+
 # The fun of cauchy combination
 def acat_test(pvalues, weights=None):
-    '''acat_test()
+    """acat_test()
     Aggregated Cauchy Assocaition Test
     A p-value combination method using the Cauchy distribution.
 
@@ -23,27 +24,28 @@ def acat_test(pvalues, weights=None):
     weights: <list or numpy array>, default=None
         The weights for each of the p-values. If None, equal weights are used.
 
-    Returns:
+    Returns
+    -------
     pval: <float>
         The ACAT combined p-value.
-    '''
+    """
     if any(np.isnan(pvalues)):
         raise Exception("Cannot have NAs in the p-values.")
-    if any([(i > 1) | (i < 0) for i in pvalues]):
+    if any((i > 1) | (i < 0) for i in pvalues):
         raise Exception("P-values must be between 0 and 1.")
-    if any([i == 1 for i in pvalues]) & any([i == 0 for i in pvalues]):
+    if any(i == 1 for i in pvalues) & any(i == 0 for i in pvalues):
         raise Exception("Cannot have both 0 and 1 p-values.")
-    if any([i == 0 for i in pvalues]):
+    if any(i == 0 for i in pvalues):
         logger.info("Warn: p-values are exactly 0.")
         return 0
-    if any([i == 1 for i in pvalues]):
+    if any(i == 1 for i in pvalues):
         logger.info("Warn: p-values are exactly 1.")
         return 1
-    if weights == None:
+    if weights is None:
         weights = [1 / len(pvalues) for i in pvalues]
     elif len(weights) != len(pvalues):
         raise Exception("Length of weights and p-values differs.")
-    elif any([i < 0 for i in weights]):
+    elif any(i < 0 for i in weights):
         raise Exception("All weights must be positive.")
     else:
         weights = [i / len(weights) for i in weights]
@@ -51,7 +53,7 @@ def acat_test(pvalues, weights=None):
     pvalues = np.array(pvalues)
     weights = np.array(weights)
 
-    if any([i < 1e-16 for i in pvalues]) == False:
+    if not any(i < 1e-16 for i in pvalues):
         cct_stat = sum(weights * np.tan((0.5 - pvalues) * np.pi))
     else:
         is_small = [i < (1e-16) for i in pvalues]
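
For context, when no p-value is below 1e-16, `cct_stat` is the weighted Cauchy statistic T = Σᵢ wᵢ · tan((0.5 − pᵢ) · π), and the combined p-value is then read off the standard Cauchy tail (the back-transform line falls outside the hunks shown here). A minimal equal-weight sketch of the same idea, not the package's exact code:

```python
import numpy as np
from scipy.stats import cauchy

def acat_sketch(pvalues):
    """Equal-weight ACAT: map p-values through the Cauchy quantile, average, invert."""
    p = np.asarray(pvalues, dtype=float)
    w = np.full(p.shape, 1.0 / len(p))         # equal weights summing to 1
    t = np.sum(w * np.tan((0.5 - p) * np.pi))  # Cauchy-transformed statistic
    return cauchy.sf(t)                        # survival function -> combined p-value

print(acat_sketch([0.01, 0.20, 0.60]))  # small inputs pull the combined p down
```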
@@ -67,75 +69,76 @@
     return pval
 
 
-def run_Cauchy_combination(config:CauchyCombinationConfig):
-    # Load the ldsc results
-    logger.info(f'------Loading LDSC results of {config.ldsc_save_dir}...')
-    ldsc_input_file= config.get_ldsc_result_file(config.trait_name)
-    ldsc = pd.read_csv(ldsc_input_file, compression='gzip')
-    ldsc.spot = ldsc.spot.astype(str).replace('\.0', '', regex=True)
-    ldsc.index = ldsc.spot
-    if config.meta is None:
-        # Load the spatial data
-        logger.info(f'------Loading ST data of {config.hdf5_with_latent_path}...')
-        spe = sc.read_h5ad(f'{config.hdf5_with_latent_path}')
-
-        common_cell = np.intersect1d(ldsc.index, spe.obs_names)
-        spe = spe[common_cell]
-        ldsc = ldsc.loc[common_cell]
-
-        # Add the annotation
-        ldsc['annotation'] = spe.obs.loc[ldsc.spot][config.annotation].to_list()
-
-    elif config.meta is not None:
-        # Or Load the additional annotation (just for the macaque data at this stage: 2023Nov25)
-        logger.info(f'------Loading additional annotation...')
-        meta = pd.read_csv(config.meta, index_col=0)
-        meta = meta.loc[meta.slide == config.slide]
-        meta.index = meta.cell_id.astype(str).replace('\.0', '', regex=True)
-
-        common_cell = np.intersect1d(ldsc.index, meta.index)
-        meta = meta.loc[common_cell]
-        ldsc = ldsc.loc[common_cell]
-
-        # Add the annotation
-        ldsc['annotation'] = meta.loc[ldsc.spot][config.annotation].to_list()
-    # Perform the Cauchy combination based on the given annotations
+def run_Cauchy_combination(config: CauchyCombinationConfig):
+    ldsc_list = []
+
+    for sample_name in config.sample_name_list:
+        config.sample_name = sample_name
+
+        # Load the LDSC results for the current sample
+        logger.info(f"------Loading LDSC results for sample {sample_name}...")
+        ldsc_input_file = config.get_ldsc_result_file(
+            trait_name=config.trait_name,
+        )
+        ldsc = pd.read_csv(ldsc_input_file, compression="gzip")
+        ldsc["spot"] = ldsc["spot"].astype(str)
+        ldsc.index = ldsc["spot"]
+
+        # Load the spatial transcriptomics (ST) data for the current sample
+        logger.info(f"------Loading ST data for sample {sample_name}...")
+        h5ad_file = config.hdf5_with_latent_path
+        adata = sc.read_h5ad(h5ad_file)
+
+        # Identify common cells between LDSC results and ST data
+        common_cells = np.intersect1d(ldsc.index, adata.obs_names)
+        adata = adata[common_cells]
+        ldsc = ldsc.loc[common_cells]
+
+        # Add annotations to the LDSC dataframe
+        ldsc["annotation"] = adata.obs.loc[ldsc.spot, config.annotation].to_list()
+        ldsc_list.append(ldsc)
+
+    # Concatenate all LDSC dataframes from different samples
+    ldsc_all = pd.concat(ldsc_list)
+
+    # Run the Cauchy combination
     p_cauchy = []
     p_median = []
-    for ct in np.unique(ldsc.annotation):
-        p_temp = ldsc.loc[ldsc['annotation'] == ct, 'p']
-
-        # The Cauchy test is sensitive to very small p-values, so extreme outliers should be considered for removal...
-        # to enhance robustness, particularly in cases where spot annotations may be incorrect.
-        # p_cauchy_temp = acat_test(p_temp[p_temp != np.min(p_temp)])
-        p_temp_log = -np.log10(p_temp)
-        median_log = np.median(p_temp_log)
-        IQR_log = np.percentile(p_temp_log, 75) - np.percentile(p_temp_log, 25)
-
-        p_use = p_temp[p_temp_log < median_log + 3*IQR_log]
-        n_remove = len(p_temp) - len(p_use)
-
-        # Outlier: -log10(p) < median + 3IQR && len(outlier set) < 20
-        if (0 < n_remove < 20):
-            logger.info(f'Remove {n_remove}/{len(p_temp)} outliers (median + 3IQR) for {ct}.')
-            p_cauchy_temp = acat_test(p_use)
+    annotations = ldsc_all["annotation"].unique()
+
+    for ct in annotations:
+        p_values = ldsc_all.loc[ldsc_all["annotation"] == ct, "p"]
+
+        # Handle extreme outliers to enhance robustness
+        p_values_log = -np.log10(p_values)
+        median_log = np.median(p_values_log)
+        iqr_log = np.percentile(p_values_log, 75) - np.percentile(p_values_log, 25)
+
+        p_values_filtered = p_values[p_values_log < median_log + 3 * iqr_log]
+        n_removed = len(p_values) - len(p_values_filtered)
+
+        # Remove outliers if the number is reasonable
+        if 0 < n_removed < 20:
+            logger.info(f"Removed {n_removed}/{len(p_values)} outliers (median + 3IQR) for {ct}.")
+            p_cauchy_temp = acat_test(p_values_filtered)
         else:
-            p_cauchy_temp = acat_test(p_temp)
-
-        p_median_temp = np.median(p_temp)
+            p_cauchy_temp = acat_test(p_values)
 
+        p_median_temp = np.median(p_values)
         p_cauchy.append(p_cauchy_temp)
         p_median.append(p_median_temp)
-    # p_tissue = pd.DataFrame(p_cauchy,p_median,np.unique(ldsc.annotation))
-    data = {'p_cauchy': p_cauchy, 'p_median': p_median, 'annotation': np.unique(ldsc.annotation)}
-    p_tissue = pd.DataFrame(data)
-    p_tissue.columns = ['p_cauchy', 'p_median', 'annotation']
+
+    # Prepare the results dataframe
+    results = pd.DataFrame({"annotation": annotations, "p_cauchy": p_cauchy, "p_median": p_median})
+    results.sort_values(by="p_cauchy", inplace=True)
+
     # Save the results
-    output_dir = Path(config.cauchy_save_dir)
-    output_dir.mkdir(parents=True, exist_ok=True, mode=0o755)
-    output_file = output_dir / f'{config.sample_name}_{config.trait_name}.Cauchy.csv.gz'
-    p_tissue.to_csv(
+    Path(config.output_file).parent.mkdir(parents=True, exist_ok=True, mode=0o755)
+    output_file = Path(config.output_file)
+    results.to_csv(
         output_file,
-        compression='gzip',
+        compression="gzip",
         index=False,
     )
+    logger.info(f"Cauchy combination results saved at {output_file}.")
+    return results
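
The rewritten `run_Cauchy_combination` now loops over `config.sample_name_list`, concatenates the per-sample LDSC tables, and, before combining p-values within each annotation, trims spots whose −log10(p) exceeds the annotation's median + 3·IQR (the code applies the trim only when fewer than 20 spots would be dropped, since ACAT is sensitive to extreme outliers from mislabeled spots). A sketch of just that filter on hypothetical data:

```python
import numpy as np
import pandas as pd

# Hypothetical per-spot p-values for one annotation; 1e-30 is an extreme outlier.
p_values = pd.Series([0.04, 0.03, 0.05, 0.02, 1e-30])

log_p = -np.log10(p_values)
median_log = np.median(log_p)
iqr_log = np.percentile(log_p, 75) - np.percentile(log_p, 25)

# Keep spots whose -log10(p) stays below median + 3*IQR; the outlier is dropped.
p_filtered = p_values[log_p < median_log + 3 * iqr_log]
print(p_filtered.tolist())  # [0.04, 0.03, 0.05, 0.02]
```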