ONTraC-2.0rc7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ONTraC/GNN/_GNN.py +206 -0
- ONTraC/GNN/__init__.py +1 -0
- ONTraC/__init__.py +3 -0
- ONTraC/analysis/__init__.py +0 -0
- ONTraC/analysis/cell_type.py +605 -0
- ONTraC/analysis/constants.py +2 -0
- ONTraC/analysis/data.py +401 -0
- ONTraC/analysis/niche_cluster.py +679 -0
- ONTraC/analysis/niche_net.py +137 -0
- ONTraC/analysis/spatial.py +513 -0
- ONTraC/analysis/train_loss.py +43 -0
- ONTraC/analysis/utils.py +62 -0
- ONTraC/bin/NicheTrajectory.py +40 -0
- ONTraC/bin/ONTraC.py +42 -0
- ONTraC/bin/ONTraC_GNN.py +36 -0
- ONTraC/bin/ONTraC_GP.py +47 -0
- ONTraC/bin/ONTraC_GT.py +43 -0
- ONTraC/bin/ONTraC_NN.py +36 -0
- ONTraC/bin/ONTraC_NT.py +36 -0
- ONTraC/bin/ONTraC_analysis.py +82 -0
- ONTraC/bin/__init__.py +0 -0
- ONTraC/constants.py +11 -0
- ONTraC/data.py +101 -0
- ONTraC/external/STdeconvolve.R +52 -0
- ONTraC/external/__init__.py +0 -0
- ONTraC/external/deconvolution.py +61 -0
- ONTraC/integrate/__init__.py +1 -0
- ONTraC/integrate/general_control.py +378 -0
- ONTraC/log.py +91 -0
- ONTraC/model/__init__.py +1 -0
- ONTraC/model/_model.py +157 -0
- ONTraC/model/dmon_exp_pool.py +168 -0
- ONTraC/model/norm_dense_gcn_conv.py +88 -0
- ONTraC/niche_net/__init__.py +1 -0
- ONTraC/niche_net/_niche_net.py +335 -0
- ONTraC/niche_trajectory/__init__.py +1 -0
- ONTraC/niche_trajectory/_niche_trajectory.py +173 -0
- ONTraC/niche_trajectory/algorithm.py +82 -0
- ONTraC/optparser/_IO.py +306 -0
- ONTraC/optparser/_NN.py +113 -0
- ONTraC/optparser/_NT.py +59 -0
- ONTraC/optparser/__init__.py +7 -0
- ONTraC/optparser/_analysis.py +120 -0
- ONTraC/optparser/_preprocessing.py +90 -0
- ONTraC/optparser/_train.py +234 -0
- ONTraC/optparser/command.py +407 -0
- ONTraC/preprocessing/__inti__.py +8 -0
- ONTraC/preprocessing/data.py +112 -0
- ONTraC/preprocessing/expression.py +117 -0
- ONTraC/preprocessing/pp_control.py +302 -0
- ONTraC/run/__init__.py +1 -0
- ONTraC/run/processes.py +164 -0
- ONTraC/train/__init__.py +1 -0
- ONTraC/train/_batch_train.py +237 -0
- ONTraC/train/inspect_funcs.py +190 -0
- ONTraC/train/loss_funs.py +130 -0
- ONTraC/utils/__init__.py +1 -0
- ONTraC/utils/_utils.py +112 -0
- ONTraC/utils/decorators.py +96 -0
- ONTraC/version.py +1 -0
- ONTraC-2.0rc7.dist-info/LICENSE +21 -0
- ONTraC-2.0rc7.dist-info/METADATA +132 -0
- ONTraC-2.0rc7.dist-info/RECORD +66 -0
- ONTraC-2.0rc7.dist-info/WHEEL +5 -0
- ONTraC-2.0rc7.dist-info/entry_points.txt +10 -0
- ONTraC-2.0rc7.dist-info/top_level.txt +1 -0
ONTraC/GNN/_GNN.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
import random
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import torch
|
|
8
|
+
from scipy.sparse import load_npz
|
|
9
|
+
from torch import Tensor
|
|
10
|
+
from torch_geometric.loader import DenseDataLoader
|
|
11
|
+
|
|
12
|
+
from ..data import SpatailOmicsDataset
|
|
13
|
+
from ..log import info
|
|
14
|
+
from ..train import SubBatchTrainProtocol
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def set_seed(seed: int) -> None:
    """
    Seed Python's ``random``, PyTorch and NumPy random number generators with one value.

    :param seed: int, seed shared by all three generators.
    :return: None.
    """
    # Same order as before: random -> torch -> numpy.
    for seeder in (random.seed, torch.manual_seed, np.random.seed):
        seeder(seed)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def train(nn_model: torch.nn.Module,
          BatchTrain: Type[SubBatchTrainProtocol],
          sample_loader: DenseDataLoader,
          device: torch.device,
          max_epochs: int,
          max_patience: int,
          min_delta: float,
          min_epochs: int,
          lr: float,
          save_dir: Union[str, Path],
          inspect_funcs: Optional[List[Callable]] = None,
          **kwargs) -> SubBatchTrainProtocol:
    """
    Run the GNN training process.

    An Adam optimizer is created for ``nn_model``. The model is wrapped in ``BatchTrain``, and a
    checkpoint is saved to ``{save_dir}/epoch_0.pt`` before training. The final state is saved to
    ``{save_dir}/model_state_dict.pt`` afterwards.

    :param nn_model: torch.nn.Module, model to train.
    :param BatchTrain: Type[SubBatchTrainProtocol], batch-train class used to wrap the model.
    :param sample_loader: DenseDataLoader, loader providing the training samples.
    :param device: torch.device, device to train on.
    :param max_epochs: int, max epochs.
    :param max_patience: int, max patience.
    :param min_delta: float, min delta.
    :param min_epochs: int, min epochs.
    :param lr: float, learning rate of the Adam optimizer.
    :param save_dir: Union[str, Path], save directory.
    :param inspect_funcs: Optional[List[Callable]], inspect functions.
    :param kwargs: dict, extra keyword arguments. Only those whose name ends with ``loss_weight``
        are forwarded to ``batch_train.train``; the rest are ignored.
    :return: SubBatchTrainProtocol, the trained batch-train object.
    """
    adam = torch.optim.Adam(nn_model.parameters(), lr=lr)
    trainer = BatchTrain(model=nn_model, device=torch.device(device), data_loader=sample_loader)  # type: ignore
    trainer.save(path=f'{save_dir}/epoch_0.pt')

    # Keep only the loss weight arguments from the extra keyword arguments.
    weight_kwargs: Dict[str, float] = {}
    for arg_name, arg_value in kwargs.items():
        if arg_name.endswith('loss_weight'):
            weight_kwargs[arg_name] = arg_value

    trainer.train(optimizer=adam,
                  inspect_funcs=inspect_funcs,
                  max_epochs=max_epochs,
                  max_patience=max_patience,
                  min_delta=min_delta,
                  min_epochs=min_epochs,
                  output=save_dir,
                  **weight_kwargs)
    trainer.save(path=f'{save_dir}/model_state_dict.pt')
    info(message='Training process end.')
    return trainer
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def evaluate(batch_train: SubBatchTrainProtocol) -> None:
    """
    Evaluate the performance of ONTraC model on data and log the resulting losses.

    :param batch_train: SubBatchTrainProtocol, batch train whose ``evaluate`` method returns a
        mapping from loss name to value.
    :return: None.
    """
    info(message='Evaluating process start.')
    losses: Dict[str, np.floating] = batch_train.evaluate()  # type: ignore
    info(message=f'Evaluate loss, {losses!r}')
    info(message='Evaluating process end.')
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def predict(output_dir: str, batch_train: SubBatchTrainProtocol,
            dataset: SpatailOmicsDataset) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
    """
    Predict the results of ONTraC model on data.

    Every tensor returned by ``batch_train.predict_dict`` is saved per sample as
    ``{output_dir}/{sample_name}_{key}.csv.gz``. Consolidation starts at the first sample whose
    prediction contains ``s``, ``out`` and ``out_adj``. From then on, results are consolidated
    across samples and saved as ``consolidate_out.csv.gz``, ``consolidate_s.csv.gz`` and
    ``consolidate_out_adj.csv.gz``:

    - consolidate_s: the per-sample ``s`` matrices (batch dimension removed), stacked row-wise.
    - consolidate_out: the per-sample ``out`` weighted by the sample's valid node count
      (``data.mask.sum()``). These are summed and divided by the total valid node count of the
      consolidated samples.
    - consolidate_out_adj: the sum of ``s^T @ adj @ s`` over samples. The diagonal is set to zero,
      then entries are normalized as ``A_ij / sqrt(d_i) / sqrt(d_j)``, where ``d`` are the row sums.

    :param output_dir: str, output directory.
    :param batch_train: SubBatchTrainProtocol, batch train.
    :param dataset: SpatailOmicsDataset, dataset.
    :return: consolidate_s_array, consolidate_out_adj_array. Both are None when no sample produced
        ``s``, ``out`` and ``out_adj``.
    """
    info('Predicting process start.')
    each_sample_loader = DenseDataLoader(dataset, batch_size=1)
    consolidate_flag = False
    consolidate_s_list: List[Tensor] = []
    consolidate_out: Optional[Tensor] = None
    consolidate_out_adj: Optional[Tensor] = None
    # Valid (unmasked) node count of the consolidated samples. It is accumulated in this single
    # pass rather than by iterating the loader a second time, which would re-collate every sample.
    # It also counts only samples that actually contribute to consolidate_out.
    nodes_num = 0
    for data in each_sample_loader:  # type: ignore
        info(f'Generating prediction results for {data.name[0]}.')
        data = data.to(batch_train.device)  # type: ignore
        predict_result = batch_train.predict_dict(data=data)  # type: ignore
        for key, value in predict_result.items():
            np.savetxt(fname=f'{output_dir}/{data.name[0]}_{key}.csv.gz',
                       X=value.squeeze(0).detach().cpu().numpy(),
                       delimiter=',')

        # consolidate results, starting from the first sample that has all required outputs
        if not consolidate_flag and all(k in predict_result for k in ('s', 'out', 'out_adj')):
            consolidate_flag = True
        if not consolidate_flag:
            continue

        s = predict_result['s'].squeeze(0)  # drop the batch dimension (batch_size=1)
        out = predict_result['out'].squeeze(0)
        valid_nodes = data.mask.sum()
        nodes_num = nodes_num + valid_nodes
        consolidate_s_list.append(s)
        # pooled adjacency of this sample: s^T @ adj @ s
        out_adj_ = torch.matmul(torch.matmul(s.T, data.adj.squeeze(0)), s)
        consolidate_out_adj = out_adj_ if consolidate_out_adj is None else consolidate_out_adj + out_adj_
        # weight this sample's `out` by its number of valid nodes
        weighted_out = out * valid_nodes
        consolidate_out = weighted_out if consolidate_out is None else consolidate_out + weighted_out

    consolidate_s_array, consolidate_out_adj_array = None, None
    if consolidate_flag:
        # consolidate out: node-weighted mean over the consolidated samples
        consolidate_out = consolidate_out / nodes_num  # type: ignore
        consolidate_out_array = consolidate_out.detach().cpu().numpy()  # type: ignore
        np.savetxt(fname=f'{output_dir}/consolidate_out.csv.gz', X=consolidate_out_array, delimiter=',')
        # consolidate s
        consolidate_s = torch.cat(consolidate_s_list, dim=0)
        # consolidate out_adj: zero the diagonal, then normalize by sqrt of the row sums on both sides
        ind = torch.arange(consolidate_s.shape[-1], device=consolidate_out_adj.device)  # type: ignore
        consolidate_out_adj[ind, ind] = 0  # type: ignore
        d = torch.einsum('ij->i', consolidate_out_adj)
        d = torch.sqrt(d)[:, None] + 1e-15  # epsilon avoids division by zero for empty rows
        consolidate_out_adj = (consolidate_out_adj / d) / d.transpose(0, 1)
        consolidate_s_array = consolidate_s.detach().cpu().numpy()
        consolidate_out_adj_array = consolidate_out_adj.detach().cpu().numpy()
        np.savetxt(fname=f'{output_dir}/consolidate_s.csv.gz', X=consolidate_s_array, delimiter=',')
        np.savetxt(fname=f'{output_dir}/consolidate_out_adj.csv.gz', X=consolidate_out_adj_array, delimiter=',')

    info('Predicting process end.')
    return consolidate_s_array, consolidate_out_adj_array
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def save_graph_pooling_results(meta_data_df: pd.DataFrame, dataset: SpatailOmicsDataset, rel_params: Dict,
                               consolidate_s_array: np.ndarray, output_dir: str) -> None:
    """
    Save graph pooling results as the Niche cluster (max probability for each niche & cell).

    Four files are written to ``output_dir``:

    - ``niche_level_niche_cluster.csv.gz`` and ``cell_level_niche_cluster.csv.gz``: the per-cluster
      values.
    - ``niche_level_max_niche_cluster.csv.gz`` and ``cell_level_max_niche_cluster.csv.gz``: the
      index of the column with the maximum value.

    Rows are written in the order of ``meta_data_df``'s ID column.

    :param meta_data_df: pd.DataFrame, original data. The first column (used as the ID) and the
        Sample column are used.
    :param dataset: SpatailOmicsDataset, dataset, iterated in the same order that was used to
        build ``consolidate_s_array``.
    :param rel_params: dict, relative parameters. ``rel_params['Data'][i]['NicheWeightMatrix']`` is
        the ``.npz`` niche weight matrix of the i-th sample.
    :param consolidate_s_array: np.ndarray, consolidate s array. It holds the (padded) rows of
        every sample stacked in dataset order, e.g. the first value returned by ``predict``.
    :param output_dir: str, output directory.
    :return: None.
    """
    id_name: str = meta_data_df.columns[0]

    niche_df_list: List[pd.DataFrame] = []
    cell_df_list: List[pd.DataFrame] = []
    row_offset = 0  # first row of the current sample inside consolidate_s_array
    for i, data in enumerate(dataset):
        # Each sample occupies data.x.shape[0] (padded) rows; only the first mask.sum() rows are used.
        # A running offset replaces i * data.x.shape[0], so samples with different padded sizes are
        # also located correctly.
        n_valid = int(data.mask.sum())
        consolidate_s = consolidate_s_array[row_offset:row_offset + n_valid]  # N x C
        row_offset += data.x.shape[0]
        sample_ids = meta_data_df.loc[meta_data_df['Sample'] == data.name, id_name].values

        niche_df = pd.DataFrame(consolidate_s,
                                columns=[f'NicheCluster_{c}' for c in range(consolidate_s.shape[1])])
        niche_df[id_name] = sample_ids
        niche_df_list.append(niche_df)

        # niche to cell matrix
        niche_weight_matrix = load_npz(rel_params['Data'][i]['NicheWeightMatrix'])
        # Divide each column by its sum (normalize by all niches associated with each cell), then
        # transpose; N x N. Presumably rows are niches and columns are cells -- TODO confirm.
        niche_to_cell_matrix = (niche_weight_matrix / niche_weight_matrix.sum(axis=0)).T
        consolidate_s_cell = niche_to_cell_matrix @ consolidate_s
        cell_df = pd.DataFrame(consolidate_s_cell,
                               columns=[f'NicheCluster_{c}' for c in range(consolidate_s_cell.shape[1])])
        cell_df[id_name] = sample_ids
        cell_df_list.append(cell_df)

    # Concatenate once instead of growing a DataFrame inside the loop (quadratic copying).
    order = meta_data_df[id_name]
    _save_niche_cluster_tables(pd.concat(niche_df_list, axis=0), id_name, order, 'niche', output_dir)
    _save_niche_cluster_tables(pd.concat(cell_df_list, axis=0), id_name, order, 'cell', output_dir)


def _save_niche_cluster_tables(cluster_df: pd.DataFrame, id_name: str, order: pd.Series, level: str,
                               output_dir: str) -> None:
    """
    Write one level ('niche' or 'cell') of niche cluster results: the full table and its argmax.

    :param cluster_df: pd.DataFrame, NicheCluster_* columns plus the ``id_name`` column.
    :param id_name: str, name of the ID column; it becomes the index.
    :param order: pd.Series, IDs giving the output row order.
    :param level: str, file name prefix, 'niche' or 'cell'.
    :param output_dir: str, output directory.
    :return: None.
    """
    cluster_df = cluster_df.set_index(id_name)
    cluster_df = cluster_df.loc[order, :]
    cluster_df.to_csv(f'{output_dir}/{level}_level_niche_cluster.csv.gz',
                      index=True,
                      index_label=id_name,
                      header=True)
    # argmax over the NicheCluster_* columns only (computed before Niche_Cluster is added)
    cluster_df['Niche_Cluster'] = cluster_df.values.argmax(axis=1)
    cluster_df['Niche_Cluster'].to_csv(f'{output_dir}/{level}_level_max_niche_cluster.csv.gz',
                                       index=True,
                                       index_label=id_name,
                                       header=True)
|
ONTraC/GNN/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from ._GNN import *
|
ONTraC/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
1
|
+
from .version import __version__
|
|
2
|
+
|
|
3
|
+
# Citation string for ONTraC (bioRxiv preprint, 2024).
citation = "Wang, W., Zheng, S., Shin, C. S. & Yuan, G. C. Characterizing Spatially Continuous Variations in Tissue Microenvironment through Niche Trajectory Analysis. https://www.biorxiv.org/content/10.1101/2024.04.23.590827v1. bioRxiv, 2024."
|
|
File without changes
|