chebai 0.0.1.dev0__tar.gz → 0.0.2.dev0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chebai-0.0.2.dev0/PKG-INFO +52 -0
- chebai-0.0.2.dev0/README.md +89 -0
- chebai-0.0.2.dev0/chebai/__init__.py +30 -0
- chebai-0.0.2.dev0/chebai/__main__.py +10 -0
- chebai-0.0.2.dev0/chebai/callbacks/__init__.py +0 -0
- chebai-0.0.2.dev0/chebai/callbacks/epoch_metrics.py +180 -0
- chebai-0.0.2.dev0/chebai/callbacks/model_checkpoint.py +95 -0
- chebai-0.0.2.dev0/chebai/callbacks/prediction_callback.py +55 -0
- chebai-0.0.2.dev0/chebai/callbacks.py +86 -0
- chebai-0.0.2.dev0/chebai/cli.py +97 -0
- chebai-0.0.2.dev0/chebai/loggers/__init__.py +0 -0
- chebai-0.0.2.dev0/chebai/loggers/custom.py +127 -0
- chebai-0.0.2.dev0/chebai/loss/__init__.py +0 -0
- chebai-0.0.2.dev0/chebai/loss/bce_weighted.py +98 -0
- chebai-0.0.2.dev0/chebai/loss/mixed.py +40 -0
- chebai-0.0.2.dev0/chebai/loss/pretraining.py +48 -0
- chebai-0.0.2.dev0/chebai/loss/semantic.py +532 -0
- chebai-0.0.2.dev0/chebai/models/__init__.py +2 -0
- chebai-0.0.2.dev0/chebai/models/base.py +378 -0
- {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/models/chemberta.py +6 -6
- {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/models/chemyk.py +4 -4
- chebai-0.0.2.dev0/chebai/models/electra.py +535 -0
- chebai-0.0.2.dev0/chebai/models/external/__init__.py +0 -0
- chebai-0.0.2.dev0/chebai/models/ffn.py +153 -0
- {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/models/lnn_model.py +3 -3
- {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/models/lstm.py +2 -2
- {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/models/recursive.py +3 -4
- {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/molecule.py +183 -27
- chebai-0.0.2.dev0/chebai/preprocessing/__init__.py +0 -0
- {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/preprocessing/bin/BPE_SWJ/vocab.json +1 -1
- chebai-0.0.2.dev0/chebai/preprocessing/bin/deepsmiles_token/tokens.txt +736 -0
- chebai-0.0.2.dev0/chebai/preprocessing/bin/graph/tokens.txt +376 -0
- chebai-0.0.2.dev0/chebai/preprocessing/bin/graph_properties/tokens.txt +0 -0
- chebai-0.0.1.dev0/chebai/preprocessing/bin/selfies.txt → chebai-0.0.2.dev0/chebai/preprocessing/bin/selfies/tokens.txt +735 -897
- {chebai-0.0.1.dev0/chebai/preprocessing/bin → chebai-0.0.2.dev0/chebai/preprocessing/bin/smiles_token}/tokens.txt +776 -850
- chebai-0.0.2.dev0/chebai/preprocessing/bin/smiles_token_unlabeled/tokens.txt +413 -0
- chebai-0.0.2.dev0/chebai/preprocessing/collate.py +137 -0
- chebai-0.0.2.dev0/chebai/preprocessing/datasets/__init__.py +4 -0
- chebai-0.0.2.dev0/chebai/preprocessing/datasets/base.py +1228 -0
- chebai-0.0.2.dev0/chebai/preprocessing/datasets/chebi.py +1476 -0
- chebai-0.0.2.dev0/chebai/preprocessing/datasets/pubchem.py +1009 -0
- chebai-0.0.2.dev0/chebai/preprocessing/datasets/tox21.py +344 -0
- chebai-0.0.2.dev0/chebai/preprocessing/migration/__init__.py +0 -0
- chebai-0.0.2.dev0/chebai/preprocessing/migration/chebi_data_migration.py +338 -0
- chebai-0.0.2.dev0/chebai/preprocessing/reader.py +332 -0
- chebai-0.0.2.dev0/chebai/preprocessing/structures.py +141 -0
- chebai-0.0.2.dev0/chebai/result/__init__.py +0 -0
- chebai-0.0.2.dev0/chebai/result/analyse_sem.py +721 -0
- {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/result/base.py +9 -7
- chebai-0.0.2.dev0/chebai/result/classification.py +105 -0
- chebai-0.0.2.dev0/chebai/result/evaluate_predictions.py +108 -0
- {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/result/molplot.py +7 -10
- chebai-0.0.2.dev0/chebai/result/pretraining.py +65 -0
- chebai-0.0.2.dev0/chebai/result/utils.py +235 -0
- {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/train.py +151 -17
- chebai-0.0.2.dev0/chebai/trainer/CustomTrainer.py +149 -0
- chebai-0.0.2.dev0/chebai/trainer/__init__.py +0 -0
- chebai-0.0.2.dev0/chebai.egg-info/PKG-INFO +52 -0
- chebai-0.0.2.dev0/chebai.egg-info/SOURCES.txt +97 -0
- {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai.egg-info/requires.txt +13 -10
- chebai-0.0.2.dev0/chebai.egg-info/top_level.txt +2 -0
- chebai-0.0.2.dev0/setup.py +55 -0
- chebai-0.0.2.dev0/tests/__init__.py +0 -0
- chebai-0.0.2.dev0/tests/integration/__init__.py +3 -0
- chebai-0.0.2.dev0/tests/integration/testChebiData.py +164 -0
- chebai-0.0.2.dev0/tests/integration/testChebiDynamicDataSplits.py +197 -0
- chebai-0.0.2.dev0/tests/integration/testCustomBalancedAccuracyMetric.py +166 -0
- chebai-0.0.2.dev0/tests/integration/testCustomMacroF1Metric.py +224 -0
- chebai-0.0.2.dev0/tests/integration/testPubChemData.py +164 -0
- chebai-0.0.2.dev0/tests/integration/testTox21MolNetData.py +164 -0
- chebai-0.0.2.dev0/tests/unit/__init__.py +4 -0
- chebai-0.0.2.dev0/tests/unit/collators/__init__.py +0 -0
- chebai-0.0.2.dev0/tests/unit/collators/testDefaultCollator.py +65 -0
- chebai-0.0.2.dev0/tests/unit/collators/testRaggedCollator.py +204 -0
- chebai-0.0.2.dev0/tests/unit/dataset_classes/__init__.py +0 -0
- chebai-0.0.2.dev0/tests/unit/dataset_classes/testChEBIOverX.py +125 -0
- chebai-0.0.2.dev0/tests/unit/dataset_classes/testChebiDataExtractor.py +228 -0
- chebai-0.0.2.dev0/tests/unit/dataset_classes/testChebiOverXPartial.py +175 -0
- chebai-0.0.2.dev0/tests/unit/dataset_classes/testChebiTermCallback.py +69 -0
- chebai-0.0.2.dev0/tests/unit/dataset_classes/testDynamicDataset.py +372 -0
- chebai-0.0.2.dev0/tests/unit/dataset_classes/testTox21Challenge.py +128 -0
- chebai-0.0.2.dev0/tests/unit/dataset_classes/testXYBaseDataModule.py +92 -0
- chebai-0.0.2.dev0/tests/unit/mock_data/__init__.py +0 -0
- chebai-0.0.2.dev0/tests/unit/mock_data/ontology_mock_data.py +406 -0
- chebai-0.0.2.dev0/tests/unit/mock_data/tox_mock_data.py +510 -0
- chebai-0.0.2.dev0/tests/unit/readers/__init__.py +0 -0
- chebai-0.0.2.dev0/tests/unit/readers/testChemDataReader.py +107 -0
- chebai-0.0.2.dev0/tests/unit/readers/testDataReader.py +56 -0
- chebai-0.0.2.dev0/tests/unit/readers/testDeepChemDataReader.py +115 -0
- chebai-0.0.2.dev0/tests/unit/readers/testSelfiesReader.py +127 -0
- chebai-0.0.1.dev0/PKG-INFO +0 -8
- chebai-0.0.1.dev0/README.md +0 -29
- chebai-0.0.1.dev0/chebai/__init__.py +0 -3
- chebai-0.0.1.dev0/chebai/__main__.py +0 -4
- chebai-0.0.1.dev0/chebai/cli.py +0 -86
- chebai-0.0.1.dev0/chebai/evaluate.py +0 -35
- chebai-0.0.1.dev0/chebai/experiments.py +0 -355
- chebai-0.0.1.dev0/chebai/models/base.py +0 -226
- chebai-0.0.1.dev0/chebai/models/electra.py +0 -386
- chebai-0.0.1.dev0/chebai/models/graph.py +0 -93
- chebai-0.0.1.dev0/chebai/models/graph_k2.py +0 -59
- chebai-0.0.1.dev0/chebai/models/graphyk.py +0 -86
- chebai-0.0.1.dev0/chebai/preprocessing/collate.py +0 -61
- chebai-0.0.1.dev0/chebai/preprocessing/reader.py +0 -278
- chebai-0.0.1.dev0/chebai/preprocessing/structures.py +0 -97
- chebai-0.0.1.dev0/chebai.egg-info/PKG-INFO +0 -8
- chebai-0.0.1.dev0/chebai.egg-info/SOURCES.txt +0 -38
- chebai-0.0.1.dev0/setup.py +0 -52
- {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/LICENSE +0 -0
- {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/models/strontex.py +3 -3
- {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/preprocessing/bin/BPE_SWJ/merges.txt +0 -0
- {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/preprocessing/collect_all.py +3 -3
- {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/result/prediction_json.py +0 -0
- {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai.egg-info/dependency_links.txt +0 -0
- /chebai-0.0.1.dev0/chebai.egg-info/top_level.txt → /chebai-0.0.2.dev0/chebai.egg-info/not-zip-safe +0 -0
- {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/setup.cfg +0 -0
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: chebai
|
|
3
|
+
Version: 0.0.2.dev0
|
|
4
|
+
Home-page:
|
|
5
|
+
Author: MGlauer
|
|
6
|
+
Author-email: martin.glauer@ovgu.de
|
|
7
|
+
Requires-Python: >=3.9, <3.13
|
|
8
|
+
License-File: LICENSE
|
|
9
|
+
Requires-Dist: certifi
|
|
10
|
+
Requires-Dist: idna
|
|
11
|
+
Requires-Dist: joblib
|
|
12
|
+
Requires-Dist: networkx
|
|
13
|
+
Requires-Dist: numpy<2
|
|
14
|
+
Requires-Dist: pandas
|
|
15
|
+
Requires-Dist: python-dateutil
|
|
16
|
+
Requires-Dist: pytz
|
|
17
|
+
Requires-Dist: requests
|
|
18
|
+
Requires-Dist: scikit-learn
|
|
19
|
+
Requires-Dist: scipy
|
|
20
|
+
Requires-Dist: six
|
|
21
|
+
Requires-Dist: threadpoolctl
|
|
22
|
+
Requires-Dist: torch
|
|
23
|
+
Requires-Dist: typing-extensions
|
|
24
|
+
Requires-Dist: urllib3
|
|
25
|
+
Requires-Dist: transformers
|
|
26
|
+
Requires-Dist: fastobo
|
|
27
|
+
Requires-Dist: pysmiles==1.1.2
|
|
28
|
+
Requires-Dist: scikit-network
|
|
29
|
+
Requires-Dist: svgutils
|
|
30
|
+
Requires-Dist: matplotlib
|
|
31
|
+
Requires-Dist: rdkit
|
|
32
|
+
Requires-Dist: selfies
|
|
33
|
+
Requires-Dist: lightning>=2.5
|
|
34
|
+
Requires-Dist: jsonargparse[signatures]>=4.17
|
|
35
|
+
Requires-Dist: omegaconf
|
|
36
|
+
Requires-Dist: seaborn
|
|
37
|
+
Requires-Dist: deepsmiles
|
|
38
|
+
Requires-Dist: iterative-stratification
|
|
39
|
+
Requires-Dist: wandb
|
|
40
|
+
Requires-Dist: chardet
|
|
41
|
+
Requires-Dist: pyyaml
|
|
42
|
+
Requires-Dist: torchmetrics
|
|
43
|
+
Provides-Extra: dev
|
|
44
|
+
Requires-Dist: black; extra == "dev"
|
|
45
|
+
Requires-Dist: isort; extra == "dev"
|
|
46
|
+
Requires-Dist: pre-commit; extra == "dev"
|
|
47
|
+
Dynamic: author
|
|
48
|
+
Dynamic: author-email
|
|
49
|
+
Dynamic: license-file
|
|
50
|
+
Dynamic: provides-extra
|
|
51
|
+
Dynamic: requires-dist
|
|
52
|
+
Dynamic: requires-python
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
# ChEBai
|
|
2
|
+
|
|
3
|
+
ChEBai is a deep learning library designed for the integration of deep learning methods with chemical ontologies, particularly ChEBI.
|
|
4
|
+
The library emphasizes the incorporation of the semantic qualities of the ontology into the learning process.
|
|
5
|
+
|
|
6
|
+
## Installation
|
|
7
|
+
|
|
8
|
+
To install ChEBai, follow these steps:
|
|
9
|
+
|
|
10
|
+
1. Clone the repository:
|
|
11
|
+
```
|
|
12
|
+
git clone https://github.com/ChEB-AI/python-chebai.git
|
|
13
|
+
```
|
|
14
|
+
|
|
15
|
+
2. Install the package:
|
|
16
|
+
|
|
17
|
+
```
|
|
18
|
+
cd python-chebai
|
|
19
|
+
pip install .
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
## Usage
|
|
23
|
+
|
|
24
|
+
Training and inference are abstracted using PyTorch Lightning modules.
|
|
25
|
+
Here are some CLI commands for the standard functionalities of pretraining, ontology extension, fine-tuning for toxicity and prediction.
|
|
26
|
+
For further details, see the [wiki](https://github.com/ChEB-AI/python-chebai/wiki).
|
|
27
|
+
If you face any problems, please open a new [issue](https://github.com/ChEB-AI/python-chebai/issues/new).
|
|
28
|
+
|
|
29
|
+
### Pretraining
|
|
30
|
+
```
|
|
31
|
+
python -m chebai fit --data.class_path=chebai.preprocessing.datasets.pubchem.PubchemChem --model=configs/model/electra-for-pretraining.yml --trainer=configs/training/pretraining_trainer.yml
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
### Structure-based ontology extension
|
|
35
|
+
```
|
|
36
|
+
python -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.pretrained_checkpoint=[path-to-pretrained-model] --model.load_prefix=generator. --data=[path-to-dataset-config] --model.out_dim=[number-of-labels]
|
|
37
|
+
```
|
|
38
|
+
A command with additional options may look like this:
|
|
39
|
+
```
|
|
40
|
+
python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi50.yml --model.criterion=configs/loss/bce.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_unweighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000
|
|
41
|
+
```
|
|
42
|
+
|
|
43
|
+
### Fine-tuning for Toxicity prediction
|
|
44
|
+
```
|
|
45
|
+
python -m chebai fit --config=[path-to-your-tox21-config] --trainer.callbacks=configs/training/default_callbacks.yml --model.pretrained_checkpoint=[path-to-pretrained-model]
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
### Predicting classes given SMILES strings
|
|
49
|
+
```
|
|
50
|
+
python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
|
|
51
|
+
```
|
|
52
|
+
The input files should contain a list of line-separated SMILES strings. This generates a CSV file that contains
|
|
53
|
+
one row for each SMILES string and one column for each class.
|
|
54
|
+
The `classes_path` is the path to the dataset's `raw/classes.txt` file that contains the relationship between model output and ChEBI-IDs.
|
|
55
|
+
|
|
56
|
+
## Evaluation
|
|
57
|
+
|
|
58
|
+
An example for evaluating a model trained on the ontology extension task is given in `tutorials/eval_model_basic.ipynb`.
|
|
59
|
+
It takes in the finetuned model as input for performing the evaluation.
|
|
60
|
+
|
|
61
|
+
## Cross-validation
|
|
62
|
+
You can do inner k-fold cross-validation, i.e., train models on k train-validation splits that all use the same test
|
|
63
|
+
set. For that, you need to specify the total number of folds as
|
|
64
|
+
```
|
|
65
|
+
--data.init_args.inner_k_folds=K
|
|
66
|
+
```
|
|
67
|
+
and the fold to be used in the current optimisation run as
|
|
68
|
+
```
|
|
69
|
+
--data.init_args.fold_index=I
|
|
70
|
+
```
|
|
71
|
+
To train K models, you need to do K such calls, each with a different `fold_index`. On the first call with a given
|
|
72
|
+
`inner_k_folds`, all folds will be created and stored in the data directory.
|
|
73
|
+
|
|
74
|
+
## Note for developers
|
|
75
|
+
|
|
76
|
+
If you have used ChEBai before PR #39, the file structure in which your ChEBI-data is saved has changed. This means that
|
|
77
|
+
datasets will be freshly generated. The data, however, is the same. If you want to keep the old data (including the old
|
|
78
|
+
splits), you can use a migration script. It copies the old data to the new location for a specific ChEBI class
|
|
79
|
+
(including ChEBI version and other parameters). The script can be called by specifying the data module from a config:
|
|
80
|
+
```
|
|
81
|
+
python chebai/preprocessing/migration/chebi_data_migration.py migrate --datamodule=[path-to-data-config]
|
|
82
|
+
```
|
|
83
|
+
or by specifying the class name (e.g. `ChEBIOver50`) and arguments separately:
|
|
84
|
+
```
|
|
85
|
+
python chebai/preprocessing/migration/chebi_data_migration.py migrate --class_name=[data-class] [--chebi_version=[version]]
|
|
86
|
+
```
|
|
87
|
+
The new dataset will by default generate random data splits (with a given seed).
|
|
88
|
+
To reuse a fixed data split, you have to provide the path of the csv file generated during the migration:
|
|
89
|
+
`--data.init_args.splits_file_path=[path-to-processed_data]/splits.csv`
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
# Absolute path of the directory that contains this package's __init__.py,
# computed by first making the file path absolute and then taking its parent.
MODULE_PATH = os.path.dirname(os.path.abspath(__file__))
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class CustomTensor(torch.Tensor):
    """
    A tensor "factory" class inheriting from `torch.Tensor`.

    Calling ``CustomTensor(data)`` is a thin wrapper around ``torch.tensor(data)``.
    Because ``__new__`` returns the result of ``torch.tensor`` directly, the object
    produced is a plain ``torch.Tensor`` and *not* an instance of ``CustomTensor``
    (and therefore ``__init__`` is never run on it).
    """

    def __new__(cls, data: Any) -> torch.Tensor:
        """
        Create a tensor holding ``data``.

        Args:
            data (Any): Any data accepted by ``torch.tensor``.

        Returns:
            torch.Tensor: A new tensor built by ``torch.tensor(data)``. Note that this
            is a plain ``torch.Tensor``, not a ``CustomTensor`` instance.
        """
        return torch.tensor(data)
|
|
File without changes
|
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torchmetrics
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def custom_reduce_fx(input: torch.Tensor) -> torch.Tensor:
    """
    Custom reduction function for distributed training.

    Sums the input along its first dimension.

    Args:
        input (torch.Tensor): The input tensor to be reduced; dimension 0 is summed over.

    Returns:
        torch.Tensor: The reduced tensor, i.e. ``torch.sum(input, dim=0)``.
    """
    # Report the call through logging (debug level) instead of print(), so library
    # users are not spammed on stdout every time the reduction runs.
    import logging

    logging.getLogger(__name__).debug("called reduce (device: %s)", input.device)
    return torch.sum(input, dim=0)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class MacroF1(torchmetrics.Metric):
    """
    Computes the Macro F1 score, which is the unweighted mean of F1 scores for each class.
    This implementation differs from torchmetrics.classification.MultilabelF1Score in the behaviour for undefined
    values (i.e., classes where TP+FN=0). The torchmetrics implementation sets these classes to a default value.
    Here, the mean is only taken over classes which have at least one positive sample.

    Predictions are counted as positive when they are strictly greater than ``threshold``; labels are counted
    as positive when they are non-zero.

    Args:
        num_labels (int): Number of classes/labels.
        dist_sync_on_step (bool, optional): Synchronize metric state across processes at each forward
            before returning the value at the step. Default: False.
        threshold (float, optional): Threshold for converting predicted probabilities to binary (0, 1) predictions.
            Default: 0.5.
    """

    def __init__(
        self, num_labels: int, dist_sync_on_step: bool = False, threshold: float = 0.5
    ):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        # Per-class integer counters, reduced by summation across processes.
        self.add_state(
            "true_positives",
            default=torch.zeros(num_labels, dtype=torch.int),
            dist_reduce_fx="sum",
        )
        self.add_state(
            "positive_predictions",
            default=torch.zeros(num_labels, dtype=torch.int),
            dist_reduce_fx="sum",
        )
        self.add_state(
            "positive_labels",
            default=torch.zeros(num_labels, dtype=torch.int),
            dist_reduce_fx="sum",
        )
        self.threshold = threshold

    def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
        """
        Update the state (TPs, Positive Predictions, Positive labels) with the current batch of predictions and labels.

        Args:
            preds (torch.Tensor): Predictions from the model; compared against ``self.threshold``.
            labels (torch.Tensor): Ground truth labels; any non-zero entry counts as positive. May be an
                integer, boolean or floating point tensor.
        """
        predicted_positive = preds > self.threshold
        # Binarise the labels once and reuse them for every counter. Summing the raw labels (as before) fails
        # for float labels, because a float sum cannot be added in-place to the int counters, and it was
        # inconsistent with the boolean interpretation already used for the true positives.
        labelled_positive = labels.to(torch.bool)

        self.true_positives += torch.sum(
            torch.logical_and(predicted_positive, labelled_positive), dim=0
        )
        self.positive_predictions += torch.sum(predicted_positive, dim=0)
        self.positive_labels += torch.sum(labelled_positive, dim=0)

    def compute(self) -> torch.Tensor:
        """
        Compute the Macro F1 score.

        Returns:
            torch.Tensor: The computed Macro F1 score, averaged over classes with at least one positive label.
            If no class has a positive label, the mean is taken over an empty tensor and the result is nan.
        """

        # ignore classes without positive labels
        # classes with positive labels, but no positive predictions will get a precision of "nan" (0 divided by 0),
        # which is propagated to the classwise_f1 and then turned into 0
        mask = self.positive_labels != 0
        precision = self.true_positives[mask] / self.positive_predictions[mask]
        recall = self.true_positives[mask] / self.positive_labels[mask]
        classwise_f1 = 2 * precision * recall / (precision + recall)
        # if (precision and recall are 0) or (precision is nan), set f1 to 0
        classwise_f1 = classwise_f1.nan_to_num()
        return torch.mean(classwise_f1)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class BalancedAccuracy(torchmetrics.Metric):
    """
    Per-class balanced accuracy, averaged over all classes.

    For every class the metric keeps a confusion-matrix count (TP, FP, TN, FN) and computes
    ``(TPR + TNR) / 2`` with ``TPR = TP / (TP + FN)`` and ``TNR = TN / (TN + FP)``. Rates that are
    undefined (0 / 0) are treated as 0. The result is the mean over all ``num_labels`` classes, which
    makes it useful for imbalanced data.

    Args:
        num_labels (int): Number of classes/labels.
        dist_sync_on_step (bool, optional): Synchronize metric state across processes at each forward
            before returning the value at the step. Default: False.
        threshold (float, optional): Predictions strictly above this value are positive, predictions at or
            below it are negative. Default: 0.5.
    """

    def __init__(
        self, num_labels: int, dist_sync_on_step: bool = False, threshold: float = 0.5
    ):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        # One integer counter per class for each confusion-matrix cell, summed across processes.
        for counter_name in (
            "true_positives",
            "false_positives",
            "true_negatives",
            "false_negatives",
        ):
            self.add_state(
                counter_name,
                default=torch.zeros(num_labels, dtype=torch.int),
                dist_reduce_fx="sum",
            )

        self.threshold = threshold

    def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
        """
        Accumulate the per-class confusion-matrix counts of one batch.

        Args:
            preds (torch.Tensor): Model predictions, compared against ``self.threshold``.
            labels (torch.Tensor): Ground truth labels; non-zero entries are positive.
        """
        # Boolean masks of shape (batch, classes); the positive and negative prediction masks are built from
        # two separate comparisons on purpose, so that entries comparing False both ways are in neither.
        said_yes = preds > self.threshold
        said_no = preds <= self.threshold
        is_yes = labels.to(torch.bool)
        is_no = ~is_yes

        # Reducing over the batch dimension yields one count per class.
        self.true_positives += torch.logical_and(said_yes, is_yes).sum(dim=0)
        self.false_positives += torch.logical_and(said_yes, is_no).sum(dim=0)
        self.true_negatives += torch.logical_and(said_no, is_no).sum(dim=0)
        self.false_negatives += torch.logical_and(said_no, is_yes).sum(dim=0)

    def compute(self) -> torch.Tensor:
        """
        Turn the accumulated counts into the mean balanced accuracy.

        Returns:
            torch.Tensor: Scalar mean over classes of ``(TPR + TNR) / 2``, undefined rates counted as 0.
        """
        sensitivity = torch.nan_to_num(
            self.true_positives / (self.true_positives + self.false_negatives)
        )
        specificity = torch.nan_to_num(
            self.true_negatives / (self.true_negatives + self.false_positives)
        )
        return ((sensitivity + specificity) / 2).mean()
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
from lightning.fabric.utilities.cloud_io import _is_dir
|
|
4
|
+
from lightning.fabric.utilities.types import _PATH
|
|
5
|
+
from lightning.pytorch import LightningModule, Trainer
|
|
6
|
+
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
|
|
7
|
+
from lightning.pytorch.loggers import WandbLogger
|
|
8
|
+
from lightning.pytorch.utilities.rank_zero import rank_zero_info
|
|
9
|
+
from lightning_utilities.core.rank_zero import rank_zero_warn
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class CustomModelCheckpoint(ModelCheckpoint):
    """
    ``ModelCheckpoint`` variant that stores checkpoints alongside the logger's output.

    During ``setup`` the checkpoint directory is re-derived from the first configured logger. It uses
    ``<save_dir>/<name>/<version>/checkpoints``, or ``<wandb run dir>/checkpoints`` for a ``WandbLogger``
    whose run directory is a string. Without any logger it uses
    ``<trainer.default_root_dir>/checkpoints``.
    """

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """
        Resolve, broadcast and store the checkpoint directory.

        Any ``dirpath`` currently set on the callback is discarded first, so the directory is always derived
        from the trainer's loggers. The logic mirrors the parent class and is duplicated so that this class's
        private resolver is the one being used.

        Args:
            trainer (Trainer): The Trainer instance.
            pl_module (LightningModule): The LightningModule instance.
            stage (str): The current stage (e.g. 'fit').
        """
        # Forget any previously configured or resolved path before re-resolving.
        self.dirpath = None
        local_dir = self.__resolve_ckpt_dir(trainer)
        self.dirpath = trainer.strategy.broadcast(local_dir)

        if trainer.is_global_zero and stage == "fit":
            self.__warn_if_dir_not_empty(self.dirpath)

    def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
        """
        Emit a rank-zero warning when checkpoints would be written into a non-empty directory.

        No warning is given when ``save_top_k`` is 0 or when ``dirpath`` is not an existing directory.
        Duplicated from the parent class, whose version is name-mangled and thus not reachable here.

        Args:
            dirpath (_PATH): The checkpoint directory to inspect.
        """
        if self.save_top_k == 0:
            return
        if not _is_dir(self._fs, dirpath, strict=True):
            return
        if self._fs.ls(dirpath):
            rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

    def __resolve_ckpt_dir(self, trainer: Trainer) -> _PATH:
        """
        Work out where checkpoints should be written.

        An explicitly set ``self.dirpath`` is returned unchanged. Otherwise the first logger decides the
        location, with a special case that places checkpoints inside the Weights & Biases run directory.

        Args:
            trainer (Trainer): The Trainer instance.

        Returns:
            _PATH: The checkpoint directory to use.
        """
        rank_zero_info("Resolving checkpoint dir (custom)")
        if self.dirpath is not None:
            # An explicit dirpath takes precedence over anything derived from loggers.
            return self.dirpath

        if not trainer.loggers:
            # No loggers configured: fall back to the trainer's root directory.
            ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")
        else:
            primary = trainer.loggers[0]
            save_dir = primary.save_dir
            if save_dir is None:
                save_dir = trainer.default_root_dir
            name = primary.name
            version = primary.version
            if not isinstance(version, str):
                version = f"version_{version}"

            if isinstance(primary, WandbLogger) and isinstance(
                primary.experiment.dir, str
            ):
                # Keep checkpoints next to the wandb run files.
                ckpt_path = os.path.join(primary.experiment.dir, "checkpoints")
            else:
                ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints")

        rank_zero_info(f"Now using checkpoint path {ckpt_path}")
        return ckpt_path
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import pickle
|
|
3
|
+
from typing import Any, Literal, Sequence
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from lightning.pytorch import LightningModule, Trainer
|
|
7
|
+
from lightning.pytorch.callbacks import BasePredictionWriter
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class PredictionWriter(BasePredictionWriter):
    """
    Callback that pickles prediction results to ``<output_dir>/predictions.pkl`` at the end of an epoch.

    Only ``write_on_epoch_end`` is overridden by this class.

    Args:
        output_dir (str): The directory where the prediction file will be saved; it is created if missing.
        write_interval (str): When to write predictions: "batch", "epoch" or "batch_and_epoch".
    """

    def __init__(
        self,
        output_dir: str,
        write_interval: Literal["batch", "epoch", "batch_and_epoch"],
    ):
        super().__init__(write_interval)
        self.output_dir = output_dir
        self.prediction_file_name = "predictions.pkl"

    def write_on_epoch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        predictions: Sequence[Any],
        batch_indices: Sequence[Any],
    ) -> None:
        """
        Pickle a list of result dicts, one per entry in ``predictions``.

        Each dict holds ``ident`` (the first ident of the entry), ``predictions`` (sigmoid of the entry's
        logits as a NumPy array) and ``labels`` (the first label row as a NumPy array, or None).

        Args:
            trainer (Trainer): The Trainer instance.
            pl_module (LightningModule): The LightningModule instance.
            predictions (Sequence[Any]): Any sequence of predictions for the epoch.
            batch_indices (Sequence[Any]): Any sequence of batch indices.
        """
        # NOTE(review): only the first ident / label of each entry is kept while "predictions" covers the
        # whole entry's logits; this presumably assumes a prediction batch size of 1 — confirm.
        results = [
            dict(
                ident=row["data"]["idents"][0],
                predictions=torch.sigmoid(row["output"]["logits"]).numpy(),
                labels=row["labels"][0].numpy() if row["labels"] is not None else None,
            )
            for row in predictions
        ]
        # Previously the write failed with FileNotFoundError when output_dir did not exist yet.
        os.makedirs(self.output_dir, exist_ok=True)
        with open(
            os.path.join(self.output_dir, self.prediction_file_name), "wb"
        ) as fout:
            pickle.dump(results, fout)
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
from typing import Any, Dict, List, Literal, Union
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from lightning.pytorch.callbacks import BasePredictionWriter
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ChebaiPredictionWriter(BasePredictionWriter):
    """
    A prediction writer that saves batch-level and/or epoch-level predictions to disk.

    Batch predictions are stored with ``torch.save`` as ``<output_dir>/<dataloader_idx>/<batch_idx>.pt``.
    Epoch predictions are collected into a single JSON file ``<output_dir>/<target_file>`` containing one
    entry per sample with the keys ``ident``, ``labels`` and ``predictions`` (sigmoid of the logits).

    Args:
        output_dir (str): The directory where predictions will be saved; it is created if missing.
        write_interval (str): The interval at which predictions will be written.
        target_file (str): The name of the file where epoch predictions will be saved (default: "predictions.json").
    """

    def __init__(
        self,
        output_dir: str,
        write_interval: Literal["batch", "epoch", "batch_and_epoch"],
        target_file: str = "predictions.json",
    ) -> None:
        super().__init__(write_interval)
        self.output_dir = output_dir
        self.target_file = target_file

    def write_on_batch_end(
        self,
        trainer: Any,
        pl_module: Any,
        prediction: Union[torch.Tensor, List[torch.Tensor]],
        batch_indices: List[int],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """
        Saves batch predictions at the end of each batch.

        Args:
            trainer (Any): The trainer instance.
            pl_module (Any): The LightningModule instance.
            prediction (Union[torch.Tensor, List[torch.Tensor]]): The prediction output from the model.
            batch_indices (List[int]): The indices of the batch.
            batch (Any): The current batch.
            batch_idx (int): The index of the batch.
            dataloader_idx (int): The index of the dataloader.
        """
        outpath = os.path.join(self.output_dir, str(dataloader_idx), f"{batch_idx}.pt")
        os.makedirs(os.path.dirname(outpath), exist_ok=True)
        torch.save(prediction, outpath)

    def write_on_epoch_end(
        self,
        trainer: Any,
        pl_module: Any,
        predictions: List[Dict[str, Any]],
        batch_indices: List[int],
    ) -> None:
        """
        Saves all predictions at the end of each epoch in a JSON file.

        Args:
            trainer (Any): The trainer instance.
            pl_module (Any): The LightningModule instance.
            predictions (List[Dict[str, Any]]): The list of prediction outputs from the model.
            batch_indices (List[int]): The indices of the batches.
        """
        pred_list = []
        for p in predictions:
            idents = p["data"]["idents"]
            labels = p["data"]["labels"]
            if labels is not None:
                labels = labels.tolist()
            else:
                # No labels available: pad with None so that every ident still yields an entry.
                labels = [None for _ in idents]
            output = torch.sigmoid(p["output"]["logits"]).tolist()
            for ident, label, pred in zip(idents, labels, output):
                pred_list.append(dict(ident=ident, labels=label, predictions=pred))
        # Only the batch hook created directories before; with epoch-only writing the open() below failed
        # when output_dir did not exist.
        os.makedirs(self.output_dir, exist_ok=True)
        with open(os.path.join(self.output_dir, self.target_file), "wt") as fout:
            json.dump(pred_list, fout)
|