flexynesis 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flexynesis/__init__.py ADDED
@@ -0,0 +1,142 @@
1
+ """
2
+ # Flexynesis
3
+
4
+ Flexynesis is a deep-learning based multi-omics bulk sequencing data
5
+ integration suite with a focus on (pre-)clinical endpoint prediction. The
6
+ package includes multiple types of deep learning architectures such as simple
7
+ fully connected networks, supervised variational autoencoders; different
8
+ options of data layer fusion, and automates feature selection and
9
+ hyperparameter optimisation. The tools are continuously benchmarked on publicly
10
+ available datasets mostly related to the study of cancer. Some of the
11
+ applications of the methods we develop are drug response modeling in cancer
12
+ patients or preclinical models (such as cell lines and patient-derived
13
+ xenografts), cancer subtype prediction, or any other clinically relevant
14
+ outcome prediction that can be formulated as a regression or classification
15
+ problem.
16
+
17
+
18
+ # Package Contents
19
+ - data: PyTorch Dataset classes and functions to import and process multi-omics data.
20
+ - main: High-level functions for training, evaluating, and using models
21
+ - models:
22
+ - direct_pred: A multi-task fully connected neural network for direct prediction of one or more target variables
23
+ - direct_pred_cnn: A multi-task one-dimensional convolutional neural network
24
+ - direct_pred_gcnn: A multi-task graph-convolutional neural network
25
+ - supervised_vae: A multi-task Supervised Variational Autoencoder model architecture
26
+ - triplet_encoder: A fully connected neural network trained with triplet loss-based contrastive learning
27
+ - modules: Reusable components for model architecture and training
28
+ - feature_selection: Feature selection methods including Laplacian scoring and redundancy filtering
29
+ - utils: General utility functions for data manipulation and visualization
30
+ - config: Default hyperparameter optimisation spaces
31
+
32
+ # Main Features
33
+
34
+ - Various multi-modal data fusion methods using different kinds of deep learning architectures.
35
+ - Data management tools for loading, preprocessing, and augmenting data.
36
+ - Feature selection methods for effective dimensionality reduction.
37
+ - Utility functions to facilitate data manipulation and visualization.
38
+ - High-level functions for training, evaluating, and using models.
39
+
40
+
41
+ # Benchmarks
42
+
43
+ For the latest benchmark results see:
44
+ https://bimsbstatic.mdc-berlin.de/akalin/buyar/flexynesis-benchmark-datasets/dashboard.html
45
+
46
+ The code for the benchmarking pipeline is at: https://github.com/BIMSBbioinfo/flexynesis-benchmarks
47
+
48
+ # Environment
49
+
50
+ To create a clone of the development environment, use the `spec-file.txt`:
51
+ ```
52
+ conda create --name flexynesis --file spec-file.txt
53
+ conda activate flexynesis
54
+ ```
55
+
56
+ To export existing spec-file.txt:
57
+ ```
58
+ conda list --explicit > spec-file.txt
59
+ ```
60
+
61
+ # Guix
62
+
63
+ You can also create a reproducible development environment with [GNU Guix](https://guix.gnu.org). You will need [this Guix commit](https://git.savannah.gnu.org/cgit/guix.git/commit/?id=e3e011a08141058598cc7631aeb52d620a3ccb8c) or later.
64
+
65
+ ```
66
+ guix shell
67
+ ```
68
+
69
+ or
70
+
71
+ ```
72
+ guix shell -m manifest.scm
73
+ ```
74
+
75
+ You can build a Guix package from the current committed state of your git checkout like this:
76
+
77
+ ```
78
+ guix pack -f guix.scm
79
+ ```
80
+
81
+ Do this to build a Docker image containing this package together with a matching Python installation:
82
+
83
+ ```
84
+ guix pack -C none \
85
+ -e '(load "guix.scm")' \
86
+ -f docker \
87
+ -S /bin=bin -S /lib=lib -S /share=share \
88
+ glibc-locales coreutils bash python
89
+ ```
90
+
91
+ # Installation
92
+
93
+ To install the project using setuptools, you can follow these steps:
94
+
95
+ 1. Clone the project from the Git repository:
96
+ ```
97
+ git clone git@github.com:BIMSBbioinfo/flexynesis.git
98
+ ```
99
+ 2. Navigate to the project directory:
100
+ ```
101
+ cd flexynesis
102
+ ```
103
+ 3. Create a clone of the development environment using `spec-file.txt`:
104
+ ```
105
+ conda create --name flexynesis --file spec-file.txt
106
+ conda activate flexynesis
107
+ ```
108
+ 4. Install the project:
109
+ ```
110
+ pip install -e .
111
+ ```
112
+
113
+ # Testing
114
+
115
+ Run unit tests
116
+ ```bash
117
+ pytest -vvv tests/unit
118
+ ```
119
+
120
+ This will run all the unit tests in the `tests/unit` directory.
121
+
122
+ # Contributing
123
+ If you would like to contribute to the project, please open an issue or a pull request on the GitHub repository.
124
+
125
+ # License
126
+ This package is currently private and is not meant to be used outside of Arcas.ai
127
+
128
+ # Authors
129
+
130
+ - Bora Uyar, bora.uyar@mdc-berlin.de
131
+ - Taras Savchyn, Taras.Savchyn@mdc-berlin.de
132
+ - Ricardo Wurmus, Ricardo.Wurmus@mdc-berlin.de
133
+ - Altuna Akalin, Altuna.Akalin@mdc-berlin.de
134
+ """
135
+
136
+ from .modules import *
137
+ from .data import *
138
+ from .main import *
139
+ from .models import *
140
+ from .feature_selection import *
141
+ from .utils import *
142
+ from .config import *
flexynesis/__main__.py ADDED
@@ -0,0 +1,327 @@
1
import os

# OpenMP-backed libraries generally read OMP_NUM_THREADS when their runtime is
# first initialised, so it has to be set before lightning (which imports torch
# and the numeric stack) is imported; assigning it later may have no effect.
os.environ["OMP_NUM_THREADS"] = "1"

from lightning import seed_everything

# Set the seed for all the possible random number generators.
seed_everything(42, workers=True)

import lightning as pl
from typing import NamedTuple
import argparse
import random
import time
import warnings
import torch
import yaml
import pandas as pd
import flexynesis
from flexynesis.models import *
from lightning.pytorch.callbacks import EarlyStopping
12
+
13
+
14
def _build_parser():
    """Create the argument parser for the ``flexynesis`` command-line interface.

    Note: argparse %-formats help strings (and ArgumentDefaultsHelpFormatter
    appends ``%(default)s``), so a literal percent sign must be written ``%%``.

    Returns:
        argparse.ArgumentParser: parser with every CLI option registered.
    """
    parser = argparse.ArgumentParser(description="Flexynesis - Your PyTorch model training interface",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--data_path", help="(Required) Path to the folder with train/test data files", type=str, required = True)
    parser.add_argument("--model_class", help="(Required) The kind of model class to instantiate", type=str,
                        choices=["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork", "CrossModalPred"], required = True)
    parser.add_argument("--gnn_conv_type", help="If model_class is set to DirectPredGCNN, choose which graph convolution type to use", type=str,
                        choices=["GC", "GCN", "GAT", "SAGE"])
    parser.add_argument("--target_variables",
                        help="(Optional if survival variables are not set to None). "
                        "Which variables in 'clin.csv' to use for predictions, comma-separated if multiple",
                        type = str, default = None)
    parser.add_argument("--batch_variables",
                        help="(Optional) Which variables in 'clin.csv' to use for data integration / batch correction, comma-separated if multiple",
                        type = str, default = None)
    parser.add_argument("--surv_event_var", help="Which column in 'clin.csv' to use as event/status indicator for survival modeling", type = str, default = None)
    parser.add_argument("--surv_time_var", help="Which column in 'clin.csv' to use as time/duration indicator for survival modeling", type = str, default = None)
    parser.add_argument('--config_path', type=str, default=None, help='Optional path to an external hyperparameter configuration file in YAML format.')
    parser.add_argument("--fusion_type", help="How to fuse the omics layers", type=str, choices=["early", "intermediate"], default = 'intermediate')
    parser.add_argument("--hpo_iter", help="Number of iterations for hyperparameter optimisation", type=int, default = 5)
    parser.add_argument("--finetuning_samples", help="Number of samples from the test dataset to use for fine-tuning the model. Set to 0 to disable fine-tuning", type=int, default = 0)
    parser.add_argument("--variance_threshold", help="Variance threshold (as percentile) to drop low variance features (default: 1; set to 0 for no variance filtering)", type=float, default = 1)
    parser.add_argument("--correlation_threshold", help="Correlation threshold to drop highly redundant features (default: 0.8; set to 1 for no redundancy filtering)", type=float, default = 0.8)
    parser.add_argument("--restrict_to_features", help="Restrict the analysis to the list of features provided by the user (default: None)", type = str, default = None)
    parser.add_argument("--subsample", help="Downsample training set to randomly drawn N samples for training. Disabled when set to 0", type=int, default = 0)
    parser.add_argument("--features_min", help="Minimum number of features to retain after feature selection", type=int, default = 500)
    parser.add_argument("--features_top_percentile", help="Top percentile features (among the features remaining after variance filtering and data cleanup) to retain after feature selection", type=float, default = 20)
    parser.add_argument("--data_types", help="(Required) Which omic data matrices to work on, comma-separated: e.g. 'gex,cnv'", type=str, required = True)
    parser.add_argument("--input_layers",
                        help="If model_class is set to CrossModalPred, choose which data types to use as input/encoded layers. "
                        "Comma-separated if multiple",
                        type=str, default = None)
    parser.add_argument("--output_layers",
                        help="If model_class is set to CrossModalPred, choose which data types to use as output/decoded layers. "
                        "Comma-separated if multiple",
                        type=str, default = None)
    parser.add_argument("--outdir", help="Path to the output folder to save the model outputs", type=str, default = os.getcwd())
    parser.add_argument("--prefix", help="Job prefix to use for output files", type=str, default = 'job')
    parser.add_argument("--log_transform", help="whether to apply log-transformation to input data matrices", type=str, choices=['True', 'False'], default = 'False')
    parser.add_argument("--early_stop_patience", help="How many epochs to wait when no improvements in validation loss is observed (default: 10; set to -1 to disable early stopping)", type=int, default = 10)
    parser.add_argument("--hpo_patience", help="How many hyperparameter optimisation iterations to wait for when no improvements are observed (default: 10; set to 0 to disable early stopping)", type=int, default = 10)
    parser.add_argument("--use_cv", action="store_true",
                        help="(Optional) If set, a 5-fold cross-validation training will be done. Otherwise, a single training on 80%% of the dataset is done.")
    parser.add_argument("--use_loss_weighting", help="whether to apply loss-balancing using uncertainty weights method", type=str, choices=['True', 'False'], default = 'True')
    parser.add_argument("--evaluate_baseline_performance", help="whether to run Random Forest + SVMs to see the performance of off-the-shelf tools on the same dataset", type=str, choices=['True', 'False'], default = 'True')
    parser.add_argument("--threads", help="(Optional) How many threads to use when using CPU (default: 4)", type=int, default = 4)
    parser.add_argument("--use_gpu", action="store_true",
                        help="(Optional) If set, the system will attempt to use CUDA/GPU if available.")
    # DirectPredGCNN args.
    parser.add_argument("--graph", help="Graph to use, name of the database or path to the edge list on the disk.", type=str, default="STRING")
    parser.add_argument("--string_organism", help="STRING DB organism id.", type=int, default=9606)
    parser.add_argument("--string_node_name", help="Type of node name.", type=str, choices=["gene_name", "gene_id"], default="gene_name")
    return parser


def _validate_args(parser, args):
    """Check cross-argument consistency; exits through ``parser.error`` on failure.

    Args:
        parser: the parser that produced ``args`` (used to report errors).
        args: parsed ``argparse.Namespace``.
    """
    # 1. Survival variables must be given as a pair.
    if (args.surv_event_var is None) != (args.surv_time_var is None):
        parser.error("Both --surv_event_var and --surv_time_var must be provided together or left as None.")

    # 2. Every model except supervised_vae / CrossModalPred needs at least one supervision signal.
    if args.model_class not in ('supervised_vae', 'CrossModalPred'):
        if not any([args.target_variables, args.surv_event_var, args.batch_variables]):
            parser.error(' '.join(["When selecting a model other than 'supervised_vae' or 'CrossModalPred',",
                                   "you must provide at least one of --target_variables,",
                                   "survival variables (--surv_event_var and --surv_time_var)",
                                   "or --batch_variables."]))

    # 3. These models only support intermediate fusion.
    if args.fusion_type == "early" and args.model_class in ("DirectPredGCNN", "CrossModalPred"):
        parser.error(f"The '{args.model_class}' model cannot be used with early fusion type. "
                     "Use --fusion_type intermediate instead.")


def _validate_paths(data_path, outdir):
    """Raise ``FileNotFoundError`` (naming the path) if the input or output folder is missing."""
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Input --data_path doesn't exist at: {data_path}")
    if not os.path.exists(outdir):
        raise FileNotFoundError(f"Path to --outdir doesn't exist at: {outdir}")


def _select_device(use_gpu, threads):
    """Return the device string ('gpu' or 'cpu') later passed to HyperparameterTuning.

    When a GPU is requested but ``torch.cuda.is_available()`` is False, a warning
    is shown and CPU is used. On the CPU path torch's thread count is set to ``threads``.
    """
    if use_gpu and torch.cuda.is_available():
        return 'gpu'
    if use_gpu:
        warnings.warn(''.join(["\n\n!!! WARNING: GPU REQUESTED BUT NOT AVAILABLE. FALLING BACK TO CPU.\n",
                               "PERFORMANCE MAY BE DEGRADED, PARTICULARLY FOR DirectPredGCNN.\n",
                               "OTHER MODELS SHOULD HAVE REASONABLE PERFORMANCE ON CPU. \n",
                               "IF USING A SLURM SCHEDULER, ENSURE YOU REQUEST A GPU WITH: ",
                               "`srun --gpus=1 --pty flexynesis <rest of your_command>` !!!\n\n"]))
        time.sleep(3)  # wait a bit to capture user's attention to the warning
    torch.set_num_threads(threads)
    return 'cpu'


def _resolve_gnn_conv_type(model_class, gnn_conv_type):
    """Return the graph-convolution type for DirectPredGCNN ('GC' if unset), else None."""
    if model_class != 'DirectPredGCNN':
        return None
    if gnn_conv_type:
        return gnn_conv_type
    warnings.warn("\n".join([
        "\n\n!!! When running DirectPredGCNN, a convolution type can be set",
        "with the --gnn_conv_type flag. See `flexynesis -h` for full set of options.",
        "Falling back on the default convolution type: GC !!!\n\n"
    ]))
    time.sleep(3)  # wait a bit to capture user's attention to the warning
    return 'GC'


def _parse_layers(value, datatypes, label):
    """Split a comma-separated layer list and check it is a subset of ``datatypes``.

    Args:
        value: raw CLI value; returned unchanged when falsy (None or "").
        datatypes: list of requested data types.
        label: "Input" or "Output", used in the error message.

    Raises:
        ValueError: if any layer is not one of ``datatypes``.
    """
    if not value:
        return value
    layers = value.strip().split(',')
    if not all(layer in datatypes for layer in layers):
        raise ValueError(f"{label} layers {layers} are not a valid subset of the data types: ({datatypes}).")
    return layers


def _resolve_model(name):
    """Map a ``--model_class`` name to ``(model class, config name)``.

    Raises:
        ValueError: for an unknown name (argparse ``choices`` should prevent this).
    """
    available_models = {
        "DirectPred": DirectPred,
        "supervised_vae": supervised_vae,
        "MultiTripletNetwork": MultiTripletNetwork,
        "DirectPredGCNN": DirectPredGCNN,
        "CrossModalPred": CrossModalPred,
    }
    if name not in available_models:
        raise ValueError(f"Invalid model_class: {name}")
    # The hyperparameter config shares its name with the model class.
    return available_models[name], name


def _output_path(outdir, prefix, *parts):
    """Build ``<outdir>/<prefix>.<part1>.<part2>...`` for an output file."""
    return os.path.join(outdir, '.'.join([prefix, *parts]))


def _finetune(model, test_dataset, n_samples):
    """Fine-tune ``model`` on ``n_samples`` randomly drawn test samples.

    Returns:
        tuple: (fine-tuned model, test dataset restricted to the hold-out samples).

    Raises:
        ValueError: if ``n_samples`` is not smaller than the test-set size, since
            otherwise no samples would be left for evaluation.
    """
    n_test = len(test_dataset)
    if n_samples >= n_test:
        raise ValueError(f"--finetuning_samples ({n_samples}) must be smaller than the number "
                         f"of test samples ({n_test}) so that some samples remain for evaluation.")
    print("[INFO] Finetuning the model on ",n_samples,"test samples")
    finetune_indices = random.sample(range(n_test), n_samples)
    selected = set(finetune_indices)
    holdout_indices = [i for i in range(n_test) if i not in selected]
    finetune_dataset = test_dataset.subset(finetune_indices)
    holdout_dataset = test_dataset.subset(holdout_indices)

    # The original note says encoders are frozen during fine-tuning; that logic lives in FineTuner.
    finetuner = flexynesis.FineTuner(model, finetune_dataset)
    finetuner.run_experiments()
    return finetuner.model, holdout_dataset


def main():
    """Command-line entry point: train, evaluate and export a flexynesis model.

    Parses and validates arguments, imports and pre-processes the data, runs
    hyperparameter tuning, optionally fine-tunes on part of the test set, then
    writes feature logs, metrics, predicted labels, feature importances,
    embeddings, decoded layers (CrossModalPred only), baseline metrics and the
    trained model into ``--outdir``. Every file name starts with ``--prefix``.
    """
    parser = _build_parser()

    warnings.filterwarnings("ignore", ".*does not have many workers.*")
    warnings.filterwarnings("ignore", "has been removed as a dependency of the")
    warnings.filterwarnings("ignore", "The `srun` command is available on your system but is not used")

    args = parser.parse_args()
    _validate_args(parser, args)
    # Fail fast on bad paths, before any slow work (warning pauses, data import).
    _validate_paths(args.data_path, args.outdir)

    # Check device availability when --use_gpu is set.
    device_type = _select_device(args.use_gpu, args.threads)
    gnn_conv_type = _resolve_gnn_conv_type(args.model_class, args.gnn_conv_type)

    datatypes = args.data_types.strip().split(',')
    input_layers, output_layers = args.input_layers, args.output_layers
    if args.model_class == 'CrossModalPred':
        # Input/output layers must be a subset of the requested data types.
        input_layers = _parse_layers(input_layers, datatypes, "Input")
        output_layers = _parse_layers(output_layers, datatypes, "Output")

    model_class, config_name = _resolve_model(args.model_class)
    # Only the graph model uses a graph.
    graph = args.graph if config_name == "DirectPredGCNN" else None

    # concatenate=True runs early fusion; otherwise intermediate fusion is used.
    data_importer = flexynesis.DataImporter(path = args.data_path,
                                            data_types = datatypes,
                                            concatenate = args.fusion_type == 'early',
                                            log_transform = args.log_transform == 'True',
                                            variance_threshold = args.variance_threshold/100,
                                            correlation_threshold = args.correlation_threshold,
                                            restrict_to_features = args.restrict_to_features,
                                            min_features= args.features_min,
                                            top_percentile= args.features_top_percentile,
                                            graph=graph,
                                            processed_dir = '_'.join(['processed', args.prefix]),
                                            string_organism=args.string_organism,
                                            string_node_name=args.string_node_name,
                                            downsample = args.subsample)
    train_dataset, test_dataset = data_importer.import_data(force = True)

    # Save feature logs (used to track which features are dropped/selected and why).
    feature_logs = data_importer.feature_logs
    for key in feature_logs.keys():
        feature_logs[key].to_csv(_output_path(args.outdir, args.prefix, 'feature_logs', key, 'csv'),
                                 header=True, index=False)

    # The tuner instantiates `model_class` using the tuning configuration named `config_name`.
    tuner = flexynesis.HyperparameterTuning(dataset = train_dataset,
                                            model_class = model_class,
                                            target_variables = args.target_variables.strip().split(',') if args.target_variables is not None else [],
                                            batch_variables = args.batch_variables.strip().split(',') if args.batch_variables is not None else None,
                                            surv_event_var = args.surv_event_var,
                                            surv_time_var = args.surv_time_var,
                                            config_name = config_name,
                                            config_path = args.config_path,
                                            n_iter=int(args.hpo_iter),
                                            use_loss_weighting = args.use_loss_weighting == 'True',
                                            use_cv = args.use_cv,
                                            early_stop_patience = int(args.early_stop_patience),
                                            device_type = device_type,
                                            gnn_conv_type = gnn_conv_type,
                                            input_layers = input_layers,
                                            output_layers = output_layers)

    # Hyperparameter search: trains multiple models and returns the best one.
    model, best_params = tuner.perform_tuning(hpo_patience = args.hpo_patience)

    # Optional fine-tuning; the test set is reduced to the hold-out samples afterwards.
    if args.finetuning_samples > 0:
        model, test_dataset = _finetune(model, test_dataset, args.finetuning_samples)

    # Evaluate predictions (only if any supervised learning happened).
    if any([args.target_variables, args.surv_event_var, args.batch_variables]):
        print("[INFO] Computing model evaluation metrics")
        metrics_df = flexynesis.evaluate_wrapper(model.predict(test_dataset), test_dataset,
                                                 surv_event_var=model.surv_event_var,
                                                 surv_time_var=model.surv_time_var)
        metrics_df.to_csv(_output_path(args.outdir, args.prefix, 'stats.csv'), header=True, index=False)

    # Known/predicted labels for train and test samples.
    predicted_labels = pd.concat([flexynesis.get_predicted_labels(model.predict(train_dataset), train_dataset, 'train'),
                                  flexynesis.get_predicted_labels(model.predict(test_dataset), test_dataset, 'test')],
                                 ignore_index=True)
    predicted_labels.to_csv(_output_path(args.outdir, args.prefix, 'predicted_labels.csv'), header=True, index=False)

    # Feature importance; pd.concat([]) would raise, so skip when there are no target variables.
    if model.target_variables:
        print("[INFO] Computing variable importance scores")
        for var in model.target_variables:
            model.compute_feature_importance(train_dataset, var, steps = 50)
        df_imp = pd.concat([model.feature_importances[x] for x in model.target_variables],
                           ignore_index = True)
        df_imp.to_csv(_output_path(args.outdir, args.prefix, 'feature_importance.csv'), header=True, index=False)

    # Sample embeddings.
    print("[INFO] Extracting sample embeddings")
    embeddings_train = model.transform(train_dataset)
    embeddings_test = model.transform(test_dataset)
    embeddings_train.to_csv(_output_path(args.outdir, args.prefix, 'embeddings_train.csv'), header=True)
    embeddings_test.to_csv(_output_path(args.outdir, args.prefix, 'embeddings_test.csv'), header=True)

    # Drop batch-associated embedding dims, keeping target-variable associated dims.
    if args.batch_variables is not None:
        print("[INFO] Printing filtered embeddings")
        embeddings_train_filtered = flexynesis.remove_batch_associated_variables(data = embeddings_train,
                                                                                 batch_dict={x: train_dataset.ann[x] for x in model.batch_variables} if model.batch_variables is not None else None,
                                                                                 target_dict={x: train_dataset.ann[x] for x in model.target_variables},
                                                                                 variable_types=train_dataset.variable_types)
        # Keep the same dims in the test embeddings as in the filtered training embeddings.
        embeddings_test_filtered = embeddings_test[embeddings_train_filtered.columns]
        embeddings_train_filtered.to_csv(_output_path(args.outdir, args.prefix, 'embeddings_train.filtered.csv'), header=True)
        embeddings_test_filtered.to_csv(_output_path(args.outdir, args.prefix, 'embeddings_test.filtered.csv'), header=True)

    # Architectures with decoders: save the decoded output layers.
    if args.model_class == 'CrossModalPred':
        print("[INFO] Printing decoded output layers")
        output_layers_train = model.decode(train_dataset)
        output_layers_test = model.decode(test_dataset)
        for layer in output_layers_train.keys():
            output_layers_train[layer].to_csv(_output_path(args.outdir, args.prefix, 'train_decoded', layer, 'csv'), header=True)
        for layer in output_layers_test.keys():
            output_layers_test[layer].to_csv(_output_path(args.outdir, args.prefix, 'test_decoded', layer, 'csv'), header=True)

    # Off-the-shelf baselines on the first target variable and/or the survival variables.
    if args.evaluate_baseline_performance == 'True':
        metrics = pd.DataFrame()
        if model.target_variables:  # indexing [0] would fail on an empty list
            var = model.target_variables[0]
            print("[INFO] Computing off-the-shelf method performance on first target variable:",var)
            if var != model.surv_event_var:
                metrics = flexynesis.evaluate_baseline_performance(train_dataset, test_dataset,
                                                                   variable_name = var,
                                                                   n_folds=5,
                                                                   n_jobs = int(args.threads))
        if model.surv_event_var and model.surv_time_var:
            print("[INFO] Computing off-the-shelf method performance on survival variable:",model.surv_time_var)
            metrics_baseline_survival = flexynesis.evaluate_baseline_survival_performance(train_dataset, test_dataset,
                                                                                          model.surv_time_var,
                                                                                          model.surv_event_var,
                                                                                          n_folds = 5,
                                                                                          n_jobs = int(args.threads))
            metrics = pd.concat([metrics, metrics_baseline_survival], axis = 0, ignore_index = True)

        if not metrics.empty:
            metrics.to_csv(_output_path(args.outdir, args.prefix, 'baseline.stats.csv'), header=True, index=False)

    # Save the trained model.
    torch.save(model, _output_path(args.outdir, args.prefix, 'final_model.pth'))


if __name__ == "__main__":
    main()