flexynesis 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flexynesis/__init__.py +142 -0
- flexynesis/__main__.py +327 -0
- flexynesis/cli.py +426 -0
- flexynesis/config.py +43 -0
- flexynesis/data.py +986 -0
- flexynesis/feature_selection.py +246 -0
- flexynesis/main.py +475 -0
- flexynesis/models/__init__.py +7 -0
- flexynesis/models/crossmodal_pred.py +492 -0
- flexynesis/models/direct_pred.py +330 -0
- flexynesis/models/direct_pred_cnn.py +248 -0
- flexynesis/models/direct_pred_gcnn.py +360 -0
- flexynesis/models/supervised_vae.py +482 -0
- flexynesis/models/triplet_encoder.py +401 -0
- flexynesis/modules.py +270 -0
- flexynesis/utils.py +878 -0
- flexynesis-0.1.0.dist-info/LICENCE.md +404 -0
- flexynesis-0.1.0.dist-info/METADATA +286 -0
- flexynesis-0.1.0.dist-info/RECORD +22 -0
- flexynesis-0.1.0.dist-info/WHEEL +5 -0
- flexynesis-0.1.0.dist-info/entry_points.txt +3 -0
- flexynesis-0.1.0.dist-info/top_level.txt +1 -0
flexynesis/__init__.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
"""
|
|
2
|
+
# Flexynesis
|
|
3
|
+
|
|
4
|
+
Flexynesis is a deep-learning based multi-omics bulk sequencing data
|
|
5
|
+
integration suite with a focus on (pre-)clinical endpoint prediction. The
|
|
6
|
+
package includes multiple types of deep learning architectures such as simple
|
|
7
|
+
fully connected networks, supervised variational autoencoders; different
|
|
8
|
+
options of data layer fusion, and automates feature selection and
|
|
9
|
+
hyperparameter optimisation. The tools are continuously benchmarked on publicly
|
|
10
|
+
available datasets mostly related to the study of cancer. Some of the
|
|
11
|
+
applications of the methods we develop are drug response modeling in cancer
|
|
12
|
+
patients or preclinical models (such as cell lines and patient-derived
|
|
13
|
+
xenografts), cancer subtype prediction, or any other clinically relevant
|
|
14
|
+
outcome prediction that can be formulated as a regression or classification
|
|
15
|
+
problem.
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Package Contents
|
|
19
|
+
- data: PyTorch Dataset classes and functions to import and process multi-omics data.
|
|
20
|
+
- main: High-level functions for training, evaluating, and using models
|
|
21
|
+
- models:
|
|
22
|
+
- direct_pred: A multi-task fully connected neural network for direct prediction of one or more target variables
|
|
23
|
+
- direct_pred_cnn: A multi-task one-dimensional convolutional neural network
|
|
24
|
+
- direct_pred_gcnn: A multi-task graph-convolutional neural network
|
|
25
|
+
- supervised_vae: A multi-task Supervised Variational Autoencoder model architecture
|
|
26
|
+
- triplet_encoder: A fully connected neural network trained with triplet-loss-based contrastive learning
|
|
27
|
+
- modules: Reusable components for model architecture and training
|
|
28
|
+
- feature_selection: Feature selection methods including Laplacian scoring and redundancy filtering
|
|
29
|
+
- utils: General utility functions for data manipulation and visualization
|
|
30
|
+
- config: Default hyperparameter optimisation spaces
|
|
31
|
+
|
|
32
|
+
# Main Features
|
|
33
|
+
|
|
34
|
+
- Various multi-modal data fusion methods using different kinds of deep learning architectures.
|
|
35
|
+
- Data management tools for loading, preprocessing, and augmenting data.
|
|
36
|
+
- Feature selection methods for effective dimensionality reduction.
|
|
37
|
+
- Utility functions to facilitate data manipulation and visualization.
|
|
38
|
+
- High-level functions for training, evaluating, and using models.
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# Benchmarks
|
|
42
|
+
|
|
43
|
+
For the latest benchmark results see:
|
|
44
|
+
https://bimsbstatic.mdc-berlin.de/akalin/buyar/flexynesis-benchmark-datasets/dashboard.html
|
|
45
|
+
|
|
46
|
+
The code for the benchmarking pipeline is at: https://github.com/BIMSBbioinfo/flexynesis-benchmarks
|
|
47
|
+
|
|
48
|
+
# Environment
|
|
49
|
+
|
|
50
|
+
To create a clone of the development environment, use the `spec-file.txt`:
|
|
51
|
+
```
|
|
52
|
+
conda create --name flexynesis --file spec-file.txt
|
|
53
|
+
conda activate flexynesis
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
To export existing spec-file.txt:
|
|
57
|
+
```
|
|
58
|
+
conda list --explicit > spec-file.txt
|
|
59
|
+
```
|
|
60
|
+
|
|
61
|
+
# Guix
|
|
62
|
+
|
|
63
|
+
You can also create a reproducible development environment with [GNU Guix](https://guix.gnu.org). You will need [this Guix commit](https://git.savannah.gnu.org/cgit/guix.git/commit/?id=e3e011a08141058598cc7631aeb52d620a3ccb8c) or later.
|
|
64
|
+
|
|
65
|
+
```
|
|
66
|
+
guix shell
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
or
|
|
70
|
+
|
|
71
|
+
```
|
|
72
|
+
guix shell -m manifest.scm
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
You can build a Guix package from the current committed state of your git checkout like this:
|
|
76
|
+
|
|
77
|
+
```
|
|
78
|
+
guix pack -f guix.scm
|
|
79
|
+
```
|
|
80
|
+
|
|
81
|
+
Do this to build a Docker image containing this package together with a matching Python installation:
|
|
82
|
+
|
|
83
|
+
```
|
|
84
|
+
guix pack -C none \
|
|
85
|
+
-e '(load "guix.scm")' \
|
|
86
|
+
-f docker \
|
|
87
|
+
-S /bin=bin -S /lib=lib -S /share=share \
|
|
88
|
+
glibc-locales coreutils bash python
|
|
89
|
+
```
|
|
90
|
+
|
|
91
|
+
# Installation
|
|
92
|
+
|
|
93
|
+
To install the project using setuptools, you can follow these steps:
|
|
94
|
+
|
|
95
|
+
1. Clone the project from the Git repository:
|
|
96
|
+
```
|
|
97
|
+
git clone git@github.com:BIMSBbioinfo/flexynesis.git
|
|
98
|
+
```
|
|
99
|
+
2. Navigate to the project directory:
|
|
100
|
+
```
|
|
101
|
+
cd flexynesis
|
|
102
|
+
```
|
|
103
|
+
3. Create a clone of the development environment using the `spec-file.txt`:
|
|
104
|
+
```
|
|
105
|
+
conda create --name flexynesis --file spec-file.txt
|
|
106
|
+
conda activate flexynesis
|
|
107
|
+
```
|
|
108
|
+
4. Install the project:
|
|
109
|
+
```
|
|
110
|
+
pip install -e .
|
|
111
|
+
```
|
|
112
|
+
|
|
113
|
+
# Testing
|
|
114
|
+
|
|
115
|
+
Run unit tests
|
|
116
|
+
```
|
|
117
|
+
pytest -vvv tests/unit
|
|
118
|
+
```
|
|
119
|
+
|
|
120
|
+
This will run all the unit tests in the tests directory.
|
|
121
|
+
|
|
122
|
+
# Contributing
|
|
123
|
+
If you would like to contribute to the project, please open an issue or a pull request on the GitHub repository.
|
|
124
|
+
|
|
125
|
+
# License
|
|
126
|
+
This package is currently private and is not meant to be used outside of Arcas.ai
|
|
127
|
+
|
|
128
|
+
# Authors
|
|
129
|
+
|
|
130
|
+
- Bora Uyar, bora.uyar@mdc-berlin.de
|
|
131
|
+
- Taras Savchyn, Taras.Savchyn@mdc-berlin.de
|
|
132
|
+
- Ricardo Wurmus, Ricardo.Wurmus@mdc-berlin.de
|
|
133
|
+
- Altuna Akalin, Altuna.Akalin@mdc-berlin.de
|
|
134
|
+
"""
|
|
135
|
+
|
|
136
|
+
from .modules import *
|
|
137
|
+
from .data import *
|
|
138
|
+
from .main import *
|
|
139
|
+
from .models import *
|
|
140
|
+
from .feature_selection import *
|
|
141
|
+
from .utils import *
|
|
142
|
+
from .config import *
|
flexynesis/__main__.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
1
|
+
from lightning import seed_everything
|
|
2
|
+
# Set the seed for all the possible random number generators.
|
|
3
|
+
seed_everything(42, workers=True)
|
|
4
|
+
import lightning as pl
|
|
5
|
+
from typing import NamedTuple
|
|
6
|
+
import os, yaml, torch, time, random, warnings, argparse
|
|
7
|
+
os.environ["OMP_NUM_THREADS"] = "1"
|
|
8
|
+
import pandas as pd
|
|
9
|
+
import flexynesis
|
|
10
|
+
from flexynesis.models import *
|
|
11
|
+
from lightning.pytorch.callbacks import EarlyStopping
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _build_parser():
    """Create the argument parser for the flexynesis command-line interface.

    Returns:
        argparse.ArgumentParser: parser with every flexynesis option
        registered. Defaults are shown in ``--help`` through
        ``ArgumentDefaultsHelpFormatter``.
    """
    parser = argparse.ArgumentParser(
        description="Flexynesis - Your PyTorch model training interface",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--data_path", help="(Required) Path to the folder with train/test data files", type=str, required=True)
    parser.add_argument("--model_class", help="(Required) The kind of model class to instantiate", type=str,
                        choices=["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork", "CrossModalPred"], required=True)
    parser.add_argument("--gnn_conv_type", help="If model_class is set to DirectPredGCNN, choose which graph convolution type to use", type=str,
                        choices=["GC", "GCN", "GAT", "SAGE"])
    # NOTE: adjacent string literals are concatenated verbatim, so the
    # separating space has to be written explicitly.
    parser.add_argument("--target_variables",
                        help="(Optional if survival variables are not set to None). "
                             "Which variables in 'clin.csv' to use for predictions, comma-separated if multiple",
                        type=str, default=None)
    parser.add_argument("--batch_variables",
                        help="(Optional) Which variables in 'clin.csv' to use for data integration / batch correction, comma-separated if multiple",
                        type=str, default=None)
    parser.add_argument("--surv_event_var", help="Which column in 'clin.csv' to use as event/status indicator for survival modeling", type=str, default=None)
    parser.add_argument("--surv_time_var", help="Which column in 'clin.csv' to use as time/duration indicator for survival modeling", type=str, default=None)
    parser.add_argument('--config_path', type=str, default=None, help='Optional path to an external hyperparameter configuration file in YAML format.')
    parser.add_argument("--fusion_type", help="How to fuse the omics layers", type=str, choices=["early", "intermediate"], default='intermediate')
    parser.add_argument("--hpo_iter", help="Number of iterations for hyperparameter optimisation", type=int, default=5)
    parser.add_argument("--finetuning_samples", help="Number of samples from the test dataset to use for fine-tuning the model. Set to 0 to disable fine-tuning", type=int, default=0)
    parser.add_argument("--variance_threshold", help="Variance threshold (as percentile) to drop low variance features (default: 1; set to 0 for no variance filtering)", type=float, default=1)
    parser.add_argument("--correlation_threshold", help="Correlation threshold to drop highly redundant features (default: 0.8; set to 1 for no redundancy filtering)", type=float, default=0.8)
    parser.add_argument("--restrict_to_features", help="Restrict the analyis to the list of features provided by the user (default: None)", type=str, default=None)
    parser.add_argument("--subsample", help="Downsample training set to randomly drawn N samples for training. Disabled when set to 0", type=int, default=0)
    parser.add_argument("--features_min", help="Minimum number of features to retain after feature selection", type=int, default=500)
    parser.add_argument("--features_top_percentile", help="Top percentile features (among the features remaining after variance filtering and data cleanup to retain after feature selection", type=float, default=20)
    parser.add_argument("--data_types", help="(Required) Which omic data matrices to work on, comma-separated: e.g. 'gex,cnv'", type=str, required=True)
    parser.add_argument("--input_layers",
                        help="If model_class is set to CrossModalPred, choose which data types to use as input/encoded layers. "
                             "Comma-separated if multiple",
                        type=str, default=None)
    parser.add_argument("--output_layers",
                        help="If model_class is set to CrossModalPred, choose which data types to use as output/decoded layers. "
                             "Comma-separated if multiple",
                        type=str, default=None)
    parser.add_argument("--outdir", help="Path to the output folder to save the model outputs", type=str, default=os.getcwd())
    parser.add_argument("--prefix", help="Job prefix to use for output files", type=str, default='job')
    parser.add_argument("--log_transform", help="whether to apply log-transformation to input data matrices", type=str, choices=['True', 'False'], default='False')
    parser.add_argument("--early_stop_patience", help="How many epochs to wait when no improvements in validation loss is observed (default: 10; set to -1 to disable early stopping)", type=int, default=10)
    parser.add_argument("--hpo_patience", help="How many hyperparamater optimisation iterations to wait for when no improvements are observed (default: 10; set to 0 to disable early stopping)", type=int, default=10)
    # argparse %-formats help strings, so a literal percent sign must be
    # escaped as '%%'; a bare '80% of' breaks `--help`.
    parser.add_argument("--use_cv", action="store_true",
                        help="(Optional) If set, the a 5-fold cross-validation training will be done. Otherwise, a single trainign on 80%% of the dataset is done.")
    parser.add_argument("--use_loss_weighting", help="whether to apply loss-balancing using uncertainty weights method", type=str, choices=['True', 'False'], default='True')
    parser.add_argument("--evaluate_baseline_performance", help="whether to run Random Forest + SVMs to see the performance of off-the-shelf tools on the same dataset", type=str, choices=['True', 'False'], default='True')
    parser.add_argument("--threads", help="(Optional) How many threads to use when using CPU (default: 4)", type=int, default=4)
    parser.add_argument("--use_gpu", action="store_true",
                        help="(Optional) If set, the system will attempt to use CUDA/GPU if available.")
    # DirectPredGCNN args.
    parser.add_argument("--graph", help="Graph to use, name of the database or path to the edge list on the disk.", type=str, default="STRING")
    parser.add_argument("--string_organism", help="STRING DB organism id.", type=int, default=9606)
    parser.add_argument("--string_node_name", help="Type of node name.", type=str, choices=["gene_name", "gene_id"], default="gene_name")
    return parser


def _output_file(args, *name_parts):
    """Return the path ``<outdir>/<prefix>.<part>[.<part>...]`` of an output file."""
    return os.path.join(args.outdir, '.'.join([args.prefix, *name_parts]))


def main():
    """Run the flexynesis command-line workflow.

    Parses and validates the command-line arguments, imports the train/test
    datasets, runs hyperparameter tuning, optionally fine-tunes the model on
    part of the test set, and writes evaluation metrics, predicted labels,
    feature importances, embeddings, decoded layers (CrossModalPred only),
    baseline statistics and the final model into ``--outdir``.

    Raises:
        SystemExit: via ``parser.error`` for inconsistent arguments.
        ValueError: for invalid input/output layers, an unknown model class,
            or more fine-tuning samples than available test samples.
        FileNotFoundError: if ``--data_path`` or ``--outdir`` does not exist.
    """
    parser = _build_parser()

    warnings.filterwarnings("ignore", ".*does not have many workers.*")
    warnings.filterwarnings("ignore", "has been removed as a dependency of the")
    warnings.filterwarnings("ignore", "The `srun` command is available on your system but is not used")

    args = parser.parse_args()

    # do some sanity checks on input arguments
    # 1. Check for survival variables consistency
    if (args.surv_event_var is None) != (args.surv_time_var is None):
        parser.error("Both --surv_event_var and --surv_time_var must be provided together or left as None.")

    # 2. Check for required variables for model classes
    has_supervision = any([args.target_variables, args.surv_event_var, args.batch_variables])
    if args.model_class not in ("supervised_vae", "CrossModalPred") and not has_supervision:
        parser.error("When selecting a model other than 'supervised_vae' or 'CrossModalPred', "
                     "you must provide at least one of --target_variables, "
                     "survival variables (--surv_event_var and --surv_time_var) "
                     "or --batch_variables.")

    # 3. Check for compatibility of fusion_type with DirectPredGCNN / CrossModalPred
    if args.fusion_type == "early":
        if args.model_class == "DirectPredGCNN":
            parser.error("The 'DirectPredGCNN' model cannot be used with early fusion type. "
                         "Use --fusion_type intermediate instead.")
        if args.model_class == 'CrossModalPred':
            parser.error("The 'CrossModalPred' model cannot be used with early fusion type. "
                         "Use --fusion_type intermediate instead.")

    # 4. Check for device availability if --use_gpu is set.
    if args.use_gpu and not torch.cuda.is_available():
        warnings.warn(''.join(["\n\n!!! WARNING: GPU REQUESTED BUT NOT AVAILABLE. FALLING BACK TO CPU.\n",
                               "PERFORMANCE MAY BE DEGRADED, PARTICULARLY FOR DirectPredGCNN.\n",
                               "OTHER MODELS SHOULD HAVE REASONABLE PERFORMANCE ON CPU. \n",
                               "IF USING A SLURM SCHEDULER, ENSURE YOU REQUEST A GPU WITH: ",
                               "`srun --gpus=1 --pty flexynesis <rest of your_command>` !!!\n\n"]))
        time.sleep(3)  # wait a bit to capture user's attention to the warning
        device_type = 'cpu'
    elif args.use_gpu:
        device_type = 'gpu'
    else:
        device_type = 'cpu'
    if device_type == 'cpu':
        torch.set_num_threads(args.threads)

    # 5. check GNN arguments
    gnn_conv_type = None
    if args.model_class == 'DirectPredGCNN':
        gnn_conv_type = args.gnn_conv_type
        if not gnn_conv_type:
            warning_message = "\n".join([
                "\n\n!!! When running DirectPredGCNN, a convolution type can be set",
                "with the --gnn_conv_type flag. See `flexynesis -h` for full set of options.",
                "Falling back on the default convolution type: GC !!!\n\n"
            ])
            warnings.warn(warning_message)
            time.sleep(3)  # wait a bit to capture user's attention to the warning
            gnn_conv_type = 'GC'

    # 6. Check CrossModalPred arguments
    # For other model classes the raw (unsplit) option values are passed on.
    input_layers = args.input_layers
    output_layers = args.output_layers
    datatypes = args.data_types.strip().split(',')
    if args.model_class == 'CrossModalPred':
        # input/output layers must be a subset of the requested data types
        if args.input_layers:
            input_layers = input_layers.strip().split(',')
            if not all(layer in datatypes for layer in input_layers):
                raise ValueError(f"Input layers {input_layers} are not a valid subset of the data types: ({datatypes}).")
        if args.output_layers:
            output_layers = output_layers.strip().split(',')
            if not all(layer in datatypes for layer in output_layers):
                raise ValueError(f"Output layers {output_layers} are not a valid subset of the data types: ({datatypes}).")

    # Validate paths (the path is interpolated into the message; passing it
    # as a second positional argument would turn the message into an errno).
    if not os.path.exists(args.data_path):
        raise FileNotFoundError(f"Input --data_path doesn't exist at: {args.data_path}")
    if not os.path.exists(args.outdir):
        raise FileNotFoundError(f"Path to --outdir doesn't exist at: {args.outdir}")

    # Registry of model classes selectable from the command line; the
    # hyperparameter configuration name is the same as the CLI name.
    available_models = {
        "DirectPred": DirectPred,
        "supervised_vae": supervised_vae,
        "MultiTripletNetwork": MultiTripletNetwork,
        "DirectPredGCNN": DirectPredGCNN,
        "CrossModalPred": CrossModalPred,
    }
    model_class = available_models.get(args.model_class)
    if model_class is None:
        raise ValueError(f"Invalid model_class: {args.model_class}")
    config_name = args.model_class

    # Fix graph argument for non GNN models.
    graph = args.graph if config_name == "DirectPredGCNN" else None

    # import assays and labels
    # concatenate=True means early fusion, otherwise intermediate fusion is used
    concatenate = args.fusion_type == 'early'

    data_importer = flexynesis.DataImporter(path=args.data_path,
                                            data_types=datatypes,
                                            concatenate=concatenate,
                                            log_transform=args.log_transform == 'True',
                                            variance_threshold=args.variance_threshold / 100,
                                            correlation_threshold=args.correlation_threshold,
                                            restrict_to_features=args.restrict_to_features,
                                            min_features=args.features_min,
                                            top_percentile=args.features_top_percentile,
                                            graph=graph,
                                            processed_dir='_'.join(['processed', args.prefix]),
                                            string_organism=args.string_organism,
                                            string_node_name=args.string_node_name,
                                            downsample=args.subsample)
    train_dataset, test_dataset = data_importer.import_data(force=True)

    # Fail before the (expensive) hyperparameter search rather than after it
    # if the requested number of fine-tuning samples cannot be drawn.
    if args.finetuning_samples > len(test_dataset):
        raise ValueError(f"--finetuning_samples ({args.finetuning_samples}) is larger than "
                         f"the number of test samples ({len(test_dataset)}).")

    # print feature logs to file (we use these tables to track which features are dropped/selected and why)
    for key, log_table in data_importer.feature_logs.items():
        log_table.to_csv(_output_file(args, 'feature_logs', key, 'csv'),
                         header=True, index=False)

    # define a tuner object, which will instantiate the selected model class
    # using the input dataset and the tuning configuration from the config.py
    tuner = flexynesis.HyperparameterTuning(dataset=train_dataset,
                                            model_class=model_class,
                                            target_variables=args.target_variables.strip().split(',') if args.target_variables is not None else [],
                                            batch_variables=args.batch_variables.strip().split(',') if args.batch_variables is not None else None,
                                            surv_event_var=args.surv_event_var,
                                            surv_time_var=args.surv_time_var,
                                            config_name=config_name,
                                            config_path=args.config_path,
                                            n_iter=int(args.hpo_iter),
                                            use_loss_weighting=args.use_loss_weighting == 'True',
                                            use_cv=args.use_cv,
                                            early_stop_patience=int(args.early_stop_patience),
                                            device_type=device_type,
                                            gnn_conv_type=gnn_conv_type,
                                            input_layers=input_layers,
                                            output_layers=output_layers)

    # do a hyperparameter search training multiple models and get the best_configuration
    model, _best_params = tuner.perform_tuning(hpo_patience=args.hpo_patience)

    # if fine-tuning is enabled; fine tune the model on a portion of test samples
    if args.finetuning_samples > 0:
        finetuneSampleN = args.finetuning_samples
        print("[INFO] Finetuning the model on ", finetuneSampleN, "test samples")
        # split test dataset into finetuning and holdout datasets
        all_indices = range(len(test_dataset))
        finetune_indices = random.sample(all_indices, finetuneSampleN)
        # sorted() keeps the hold-out order deterministic
        holdout_indices = sorted(set(all_indices) - set(finetune_indices))
        finetune_dataset = test_dataset.subset(finetune_indices)
        holdout_dataset = test_dataset.subset(holdout_indices)

        # fine tune on the finetuning dataset; freeze the encoders
        finetuner = flexynesis.FineTuner(model, finetune_dataset)
        finetuner.run_experiments()

        # update the model to finetuned model
        model = finetuner.model
        # update the test dataset to exclude finetuning samples
        test_dataset = holdout_dataset

    # Snapshot of the model's target variables; it may be empty, e.g. for a
    # purely unsupervised run.
    target_vars = list(model.target_variables)

    # evaluate predictions; (if any supervised learning happened)
    if has_supervision:
        print("[INFO] Computing model evaluation metrics")
        metrics_df = flexynesis.evaluate_wrapper(model.predict(test_dataset), test_dataset,
                                                 surv_event_var=model.surv_event_var,
                                                 surv_time_var=model.surv_time_var)
        metrics_df.to_csv(_output_file(args, 'stats.csv'), header=True, index=False)

        # print known/predicted labels
        predicted_labels = pd.concat([flexynesis.get_predicted_labels(model.predict(train_dataset), train_dataset, 'train'),
                                      flexynesis.get_predicted_labels(model.predict(test_dataset), test_dataset, 'test')],
                                     ignore_index=True)
        predicted_labels.to_csv(_output_file(args, 'predicted_labels.csv'), header=True, index=False)

        # compute feature importance values
        # (pd.concat raises on an empty list, so skip when there is no target)
        if target_vars:
            print("[INFO] Computing variable importance scores")
            for var in target_vars:
                model.compute_feature_importance(train_dataset, var, steps=50)
            df_imp = pd.concat([model.feature_importances[x] for x in target_vars],
                               ignore_index=True)
            df_imp.to_csv(_output_file(args, 'feature_importance.csv'), header=True, index=False)

    # get sample embeddings and save
    print("[INFO] Extracting sample embeddings")
    embeddings_train = model.transform(train_dataset)
    embeddings_test = model.transform(test_dataset)

    embeddings_train.to_csv(_output_file(args, 'embeddings_train.csv'), header=True)
    embeddings_test.to_csv(_output_file(args, 'embeddings_test.csv'), header=True)

    # also filter embeddings to remove batch-associated dims and only keep target-variable associated dims
    if args.batch_variables is not None:
        print("[INFO] Printing filtered embeddings")
        embeddings_train_filtered = flexynesis.remove_batch_associated_variables(
            data=embeddings_train,
            batch_dict={x: train_dataset.ann[x] for x in model.batch_variables} if model.batch_variables is not None else None,
            target_dict={x: train_dataset.ann[x] for x in model.target_variables},
            variable_types=train_dataset.variable_types)
        # filter test embeddings to keep the same dims as the filtered training embeddings
        embeddings_test_filtered = embeddings_test[embeddings_train_filtered.columns]

        # save
        embeddings_train_filtered.to_csv(_output_file(args, 'embeddings_train.filtered.csv'), header=True)
        embeddings_test_filtered.to_csv(_output_file(args, 'embeddings_test.filtered.csv'), header=True)

    # for architectures with decoders; print decoded output layers
    if args.model_class == 'CrossModalPred':
        print("[INFO] Printing decoded output layers")
        output_layers_train = model.decode(train_dataset)
        output_layers_test = model.decode(test_dataset)
        for layer, decoded in output_layers_train.items():
            decoded.to_csv(_output_file(args, 'train_decoded', layer, 'csv'), header=True)
        for layer, decoded in output_layers_test.items():
            decoded.to_csv(_output_file(args, 'test_decoded', layer, 'csv'), header=True)

    # evaluate off-the-shelf methods on the main target variable
    if args.evaluate_baseline_performance == 'True':
        metrics = pd.DataFrame()
        # Guard against an empty target list instead of failing on [0].
        if target_vars:
            var = target_vars[0]
            print("[INFO] Computing off-the-shelf method performance on first target variable:", var)
            # the survival event column is handled by the survival baseline below
            if var != model.surv_event_var:
                metrics = flexynesis.evaluate_baseline_performance(train_dataset, test_dataset,
                                                                   variable_name=var,
                                                                   n_folds=5,
                                                                   n_jobs=int(args.threads))
        if model.surv_event_var and model.surv_time_var:
            print("[INFO] Computing off-the-shelf method performance on survival variable:", model.surv_time_var)
            metrics_baseline_survival = flexynesis.evaluate_baseline_survival_performance(train_dataset, test_dataset,
                                                                                          model.surv_time_var,
                                                                                          model.surv_event_var,
                                                                                          n_folds=5,
                                                                                          n_jobs=int(args.threads))
            metrics = pd.concat([metrics, metrics_baseline_survival], axis=0, ignore_index=True)

        if not metrics.empty:
            metrics.to_csv(_output_file(args, 'baseline.stats.csv'), header=True, index=False)

    # save the trained model in file
    # NOTE(review): the whole model object is serialised (not a state_dict),
    # so loading it presumably requires the flexynesis classes to be importable.
    torch.save(model, _output_file(args, 'final_model.pth'))
|
|
325
|
+
|
|
326
|
+
if __name__ == "__main__":
|
|
327
|
+
main()
|