ONTraC 2.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. ONTraC/GNN/_GNN.py +206 -0
  2. ONTraC/GNN/__init__.py +1 -0
  3. ONTraC/__init__.py +3 -0
  4. ONTraC/analysis/__init__.py +0 -0
  5. ONTraC/analysis/cell_type.py +605 -0
  6. ONTraC/analysis/constants.py +2 -0
  7. ONTraC/analysis/data.py +401 -0
  8. ONTraC/analysis/niche_cluster.py +679 -0
  9. ONTraC/analysis/niche_net.py +137 -0
  10. ONTraC/analysis/spatial.py +513 -0
  11. ONTraC/analysis/train_loss.py +43 -0
  12. ONTraC/analysis/utils.py +62 -0
  13. ONTraC/bin/NicheTrajectory.py +40 -0
  14. ONTraC/bin/ONTraC.py +42 -0
  15. ONTraC/bin/ONTraC_GNN.py +36 -0
  16. ONTraC/bin/ONTraC_GP.py +47 -0
  17. ONTraC/bin/ONTraC_GT.py +43 -0
  18. ONTraC/bin/ONTraC_NN.py +36 -0
  19. ONTraC/bin/ONTraC_NT.py +36 -0
  20. ONTraC/bin/ONTraC_analysis.py +82 -0
  21. ONTraC/bin/__init__.py +0 -0
  22. ONTraC/constants.py +11 -0
  23. ONTraC/data.py +101 -0
  24. ONTraC/external/STdeconvolve.R +52 -0
  25. ONTraC/external/__init__.py +0 -0
  26. ONTraC/external/deconvolution.py +61 -0
  27. ONTraC/integrate/__init__.py +1 -0
  28. ONTraC/integrate/general_control.py +378 -0
  29. ONTraC/log.py +91 -0
  30. ONTraC/model/__init__.py +1 -0
  31. ONTraC/model/_model.py +157 -0
  32. ONTraC/model/dmon_exp_pool.py +168 -0
  33. ONTraC/model/norm_dense_gcn_conv.py +88 -0
  34. ONTraC/niche_net/__init__.py +1 -0
  35. ONTraC/niche_net/_niche_net.py +335 -0
  36. ONTraC/niche_trajectory/__init__.py +1 -0
  37. ONTraC/niche_trajectory/_niche_trajectory.py +173 -0
  38. ONTraC/niche_trajectory/algorithm.py +82 -0
  39. ONTraC/optparser/_IO.py +306 -0
  40. ONTraC/optparser/_NN.py +113 -0
  41. ONTraC/optparser/_NT.py +59 -0
  42. ONTraC/optparser/__init__.py +7 -0
  43. ONTraC/optparser/_analysis.py +120 -0
  44. ONTraC/optparser/_preprocessing.py +90 -0
  45. ONTraC/optparser/_train.py +234 -0
  46. ONTraC/optparser/command.py +407 -0
  47. ONTraC/preprocessing/__inti__.py +8 -0
  48. ONTraC/preprocessing/data.py +112 -0
  49. ONTraC/preprocessing/expression.py +117 -0
  50. ONTraC/preprocessing/pp_control.py +302 -0
  51. ONTraC/run/__init__.py +1 -0
  52. ONTraC/run/processes.py +164 -0
  53. ONTraC/train/__init__.py +1 -0
  54. ONTraC/train/_batch_train.py +237 -0
  55. ONTraC/train/inspect_funcs.py +190 -0
  56. ONTraC/train/loss_funs.py +130 -0
  57. ONTraC/utils/__init__.py +1 -0
  58. ONTraC/utils/_utils.py +112 -0
  59. ONTraC/utils/decorators.py +96 -0
  60. ONTraC/version.py +1 -0
  61. ONTraC-2.0rc7.dist-info/LICENSE +21 -0
  62. ONTraC-2.0rc7.dist-info/METADATA +132 -0
  63. ONTraC-2.0rc7.dist-info/RECORD +66 -0
  64. ONTraC-2.0rc7.dist-info/WHEEL +5 -0
  65. ONTraC-2.0rc7.dist-info/entry_points.txt +10 -0
  66. ONTraC-2.0rc7.dist-info/top_level.txt +1 -0
ONTraC/GNN/_GNN.py ADDED
@@ -0,0 +1,206 @@
1
+ import random
2
+ from pathlib import Path
3
+ from typing import Callable, Dict, List, Optional, Tuple, Type, Union
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ from scipy.sparse import load_npz
9
+ from torch import Tensor
10
+ from torch_geometric.loader import DenseDataLoader
11
+
12
+ from ..data import SpatailOmicsDataset
13
+ from ..log import info
14
+ from ..train import SubBatchTrainProtocol
15
+
16
+
17
def set_seed(seed: int) -> None:
    """
    Seed the random number generators of ``random``, ``torch`` and ``numpy``.

    All three generators receive the same seed value.

    :param seed: int, seed shared by the three generators.
    :return: None.
    """

    # same order as before: Python stdlib, then PyTorch, then NumPy
    for seeder in (random.seed, torch.manual_seed, np.random.seed):
        seeder(seed)
27
+
28
+
29
def train(nn_model: torch.nn.Module,
          BatchTrain: Type[SubBatchTrainProtocol],
          sample_loader: DenseDataLoader,
          device: torch.device,
          max_epochs: int,
          max_patience: int,
          min_delta: float,
          min_epochs: int,
          lr: float,
          save_dir: Union[str, Path],
          inspect_funcs: Optional[List[Callable]] = None,
          **kwargs) -> SubBatchTrainProtocol:
    """
    Run the GNN training process and return the trainer that was used.

    The trainer is saved to ``{save_dir}/epoch_0.pt`` before training starts and to
    ``{save_dir}/model_state_dict.pt`` once training has finished.

    :param nn_model: nn model.
    :param BatchTrain: Type[SubBatchTrainProtocol], trainer class to instantiate.
    :param sample_loader: DenseDataLoader, sample loader.
    :param device: torch.device, device.
    :param max_epochs: int, max epochs.
    :param max_patience: int, max patience.
    :param min_delta: float, min delta.
    :param min_epochs: int, min epochs.
    :param lr: float, learning rate of the Adam optimizer.
    :param save_dir: Union[str, Path], save directory.
    :param inspect_funcs: Optional[List[Callable]], inspect functions.
    :param kwargs: dict, only entries whose key ends with ``loss_weight`` are forwarded to the trainer.
    :return: SubBatchTrainProtocol, the trainer after training.
    """
    adam = torch.optim.Adam(nn_model.parameters(), lr=lr)
    trainer = BatchTrain(model=nn_model, device=torch.device(device), data_loader=sample_loader)  # type: ignore

    # snapshot taken before the first epoch
    initial_path = f'{save_dir}/epoch_0.pt'
    trainer.save(path=initial_path)

    # pick the loss weights out of the free-form keyword arguments
    loss_weights: Dict[str, float] = {}
    for name, weight in kwargs.items():
        if name.endswith('loss_weight'):
            loss_weights[name] = weight

    trainer.train(optimizer=adam,
                  inspect_funcs=inspect_funcs,
                  max_epochs=max_epochs,
                  max_patience=max_patience,
                  min_delta=min_delta,
                  min_epochs=min_epochs,
                  output=save_dir,
                  **loss_weights)

    final_path = f'{save_dir}/model_state_dict.pt'
    trainer.save(path=final_path)
    info(message='Training process end.')
    return trainer
73
+
74
+
75
def evaluate(batch_train: SubBatchTrainProtocol) -> None:
    """
    Evaluate the ONTraC model held by the trainer and report the losses through ``info``.

    :param batch_train: SubBatchTrainProtocol, batch train.
    :return: None.
    """
    info(message='Evaluating process start.')
    losses: Dict[str, np.floating] = batch_train.evaluate()  # type: ignore
    info(message=f'Evaluate loss, {losses!r}')
    info(message='Evaluating process end.')
85
+
86
+
87
def predict(output_dir: str, batch_train: SubBatchTrainProtocol,
            dataset: SpatailOmicsDataset) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
    """
    Predict the results of ONTraC model on data.

    For every sample, each entry returned by ``batch_train.predict_dict`` is written to
    ``{output_dir}/{sample name}_{key}.csv.gz``. If the prediction contains the keys ``s``, ``out`` and
    ``out_adj``, the results are also consolidated over all samples and written to
    ``consolidate_out.csv.gz``, ``consolidate_s.csv.gz`` and ``consolidate_out_adj.csv.gz``.

    :param output_dir: str, output directory.
    :param batch_train: SubBatchTrainProtocol, batch train.
    :param dataset: SpatailOmicsDataset, dataset.
    :return: consolidate_s_array, consolidate_out_adj_array. Both are None when nothing was consolidated.
    """
    info('Predicting process start.')
    each_sample_loader = DenseDataLoader(dataset, batch_size=1)
    consolidate_flag = False
    consolidate_s_list: List[Tensor] = []
    consolidate_out: Optional[Tensor] = None
    consolidate_out_adj: Optional[Tensor] = None
    nodes_num = 0
    for data in each_sample_loader:  # type: ignore
        sample_name = data.name[0]
        info(f'Generating prediction results for {sample_name}.')
        data = data.to(batch_train.device)  # type: ignore
        predict_result = batch_train.predict_dict(data=data)  # type: ignore
        for key, value in predict_result.items():
            np.savetxt(fname=f'{output_dir}/{sample_name}_{key}.csv.gz',
                       X=value.squeeze(0).detach().cpu().numpy(),
                       delimiter=',')

        # Count the masked nodes of every sample in this pass, so the loader does not have to be
        # iterated a second time just to obtain the total.
        sample_nodes = data.mask.sum()
        nodes_num = nodes_num + sample_nodes

        # consolidate results
        if not consolidate_flag and all(key in predict_result for key in ('s', 'out', 'out_adj')):
            consolidate_flag = True
        if consolidate_flag:
            s = predict_result['s'].squeeze(0)
            consolidate_s_list.append(s)
            # pooled adjacency of this sample: s^T * adj * s
            out_adj_ = torch.matmul(torch.matmul(s.T, data.adj.squeeze(0)), s)
            consolidate_out_adj = out_adj_ if consolidate_out_adj is None else consolidate_out_adj + out_adj_
            # weight `out` by the sample's node count; dividing by the total below gives a weighted mean
            weighted_out = predict_result['out'].squeeze(0) * sample_nodes
            consolidate_out = weighted_out if consolidate_out is None else consolidate_out + weighted_out

    consolidate_s_array, consolidate_out_adj_array = None, None
    if consolidate_flag:
        # consolidate out
        consolidate_out = consolidate_out / nodes_num  # type: ignore
        consolidate_out_array = consolidate_out.detach().cpu().numpy()  # type: ignore
        np.savetxt(fname=f'{output_dir}/consolidate_out.csv.gz', X=consolidate_out_array, delimiter=',')
        # consolidate s
        consolidate_s = torch.cat(consolidate_s_list, dim=0)
        # consolidate out_adj: drop the diagonal, then scale entry (i, j) by
        # 1 / (sqrt(row_sum_i) * sqrt(row_sum_j)); 1e-15 guards against division by zero
        ind = torch.arange(consolidate_s.shape[-1], device=consolidate_out_adj.device)  # type: ignore
        consolidate_out_adj[ind, ind] = 0  # type: ignore
        d = torch.einsum('ij->i', consolidate_out_adj)
        d = torch.sqrt(d)[:, None] + 1e-15
        consolidate_out_adj = (consolidate_out_adj / d) / d.transpose(0, 1)
        consolidate_s_array = consolidate_s.detach().cpu().numpy()
        consolidate_out_adj_array = consolidate_out_adj.detach().cpu().numpy()
        np.savetxt(fname=f'{output_dir}/consolidate_s.csv.gz', X=consolidate_s_array, delimiter=',')
        np.savetxt(fname=f'{output_dir}/consolidate_out_adj.csv.gz', X=consolidate_out_adj_array, delimiter=',')

    info('Predicting process end.')
    return consolidate_s_array, consolidate_out_adj_array
148
+
149
+
150
def _save_niche_cluster_tables(cluster_df: pd.DataFrame, meta_data_df: pd.DataFrame, id_name: str, output_dir: str,
                               level: str) -> None:
    """
    Write the probability table and the arg-max assignment for one level ('niche' or 'cell').

    :param cluster_df: pd.DataFrame, NicheCluster_* columns plus the id column.
    :param meta_data_df: pd.DataFrame, its id column defines the row order of the output.
    :param id_name: str, name of the id column.
    :param output_dir: str, output directory.
    :param level: str, prefix of the output file names.
    :return: None.
    """
    cluster_df = cluster_df.set_index(id_name)
    cluster_df = cluster_df.loc[meta_data_df[id_name], :]
    cluster_df.to_csv(f'{output_dir}/{level}_level_niche_cluster.csv.gz',
                      index=True,
                      index_label=id_name,
                      header=True)
    # the arg-max column is added only after the probability table has been written
    cluster_df['Niche_Cluster'] = cluster_df.values.argmax(axis=1)
    cluster_df['Niche_Cluster'].to_csv(f'{output_dir}/{level}_level_max_niche_cluster.csv.gz',
                                       index=True,
                                       header=True)


def save_graph_pooling_results(meta_data_df: pd.DataFrame, dataset: SpatailOmicsDataset, rel_params: Dict,
                               consolidate_s_array: np.ndarray, output_dir: str) -> None:
    """
    Save graph pooling results as the Niche cluster (max probability for each niche & cell).

    Four files are written to ``output_dir``: ``niche_level_niche_cluster.csv.gz``,
    ``niche_level_max_niche_cluster.csv.gz``, ``cell_level_niche_cluster.csv.gz`` and
    ``cell_level_max_niche_cluster.csv.gz``.

    :param meta_data_df: pd.DataFrame, original data. Its first column is used as the id column and the
        Sample column is used to select the rows of each sample.
    :param dataset: SpatailOmicsDataset, dataset.
    :param rel_params: dict, relative parameters. ``rel_params['Data'][i]['NicheWeightMatrix']`` is the path
        of the niche weight matrix (npz) of the i-th sample.
    :param consolidate_s_array: np.ndarray, consolidate s array.
    :param output_dir: str, output directory.
    :return: None.
    """

    id_name: str = meta_data_df.columns[0]

    # Collect the per-sample frames and concatenate once; concatenating inside the loop copies the
    # accumulated frame on every iteration.
    niche_frames: List[pd.DataFrame] = []
    cell_frames: List[pd.DataFrame] = []
    for i, data in enumerate(dataset):
        # the slice of data in each sample: sample i starts at row i * data.x.shape[0] and only its
        # first data.mask.sum() rows are used
        # NOTE(review): this assumes every sample occupies the same number of rows in consolidate_s_array
        start = i * data.x.shape[0]
        n_nodes = int(data.mask.sum())
        consolidate_s = consolidate_s_array[start:start + n_nodes]  # N x C
        sample_ids = meta_data_df[meta_data_df['Sample'] == data.name][id_name].values

        niche_df = pd.DataFrame(consolidate_s, columns=[f'NicheCluster_{c}' for c in range(consolidate_s.shape[1])])
        niche_df[id_name] = sample_ids
        niche_frames.append(niche_df)

        # niche to cell matrix
        niche_weight_matrix = load_npz(rel_params['Data'][i]['NicheWeightMatrix'])
        niche_to_cell_matrix = (
            niche_weight_matrix /
            niche_weight_matrix.sum(axis=0)).T  # normalize by the all niches associated with each cell, N x N

        consolidate_s_cell = niche_to_cell_matrix @ consolidate_s
        cell_df = pd.DataFrame(consolidate_s_cell,
                               columns=[f'NicheCluster_{c}' for c in range(consolidate_s_cell.shape[1])])
        cell_df[id_name] = sample_ids
        cell_frames.append(cell_df)

    # an empty dataset falls back to an empty frame, as the previous in-loop accumulation did
    consolidate_s_niche_df = pd.concat(niche_frames, axis=0) if niche_frames else pd.DataFrame()
    consolidate_s_cell_df = pd.concat(cell_frames, axis=0) if cell_frames else pd.DataFrame()

    _save_niche_cluster_tables(consolidate_s_niche_df, meta_data_df, id_name, output_dir, 'niche')
    _save_niche_cluster_tables(consolidate_s_cell_df, meta_data_df, id_name, output_dir, 'cell')
ONTraC/GNN/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from ._GNN import *
ONTraC/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .version import __version__
2
+
3
+ citation = "Wang, W., Zheng, S., Shin, C. S. & Yuan, G. C. Characterizing Spatially Continuous Variations in Tissue Microenvironment through Niche Trajectory Analysis. https://www.biorxiv.org/content/10.1101/2024.04.23.590827v1. bioRxiv, 2024."
File without changes