chebai 0.0.1.dev0__tar.gz → 0.0.2.dev0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116) hide show
  1. chebai-0.0.2.dev0/PKG-INFO +52 -0
  2. chebai-0.0.2.dev0/README.md +89 -0
  3. chebai-0.0.2.dev0/chebai/__init__.py +30 -0
  4. chebai-0.0.2.dev0/chebai/__main__.py +10 -0
  5. chebai-0.0.2.dev0/chebai/callbacks/__init__.py +0 -0
  6. chebai-0.0.2.dev0/chebai/callbacks/epoch_metrics.py +180 -0
  7. chebai-0.0.2.dev0/chebai/callbacks/model_checkpoint.py +95 -0
  8. chebai-0.0.2.dev0/chebai/callbacks/prediction_callback.py +55 -0
  9. chebai-0.0.2.dev0/chebai/callbacks.py +86 -0
  10. chebai-0.0.2.dev0/chebai/cli.py +97 -0
  11. chebai-0.0.2.dev0/chebai/loggers/__init__.py +0 -0
  12. chebai-0.0.2.dev0/chebai/loggers/custom.py +127 -0
  13. chebai-0.0.2.dev0/chebai/loss/__init__.py +0 -0
  14. chebai-0.0.2.dev0/chebai/loss/bce_weighted.py +98 -0
  15. chebai-0.0.2.dev0/chebai/loss/mixed.py +40 -0
  16. chebai-0.0.2.dev0/chebai/loss/pretraining.py +48 -0
  17. chebai-0.0.2.dev0/chebai/loss/semantic.py +532 -0
  18. chebai-0.0.2.dev0/chebai/models/__init__.py +2 -0
  19. chebai-0.0.2.dev0/chebai/models/base.py +378 -0
  20. {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/models/chemberta.py +6 -6
  21. {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/models/chemyk.py +4 -4
  22. chebai-0.0.2.dev0/chebai/models/electra.py +535 -0
  23. chebai-0.0.2.dev0/chebai/models/external/__init__.py +0 -0
  24. chebai-0.0.2.dev0/chebai/models/ffn.py +153 -0
  25. {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/models/lnn_model.py +3 -3
  26. {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/models/lstm.py +2 -2
  27. {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/models/recursive.py +3 -4
  28. {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/molecule.py +183 -27
  29. chebai-0.0.2.dev0/chebai/preprocessing/__init__.py +0 -0
  30. {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/preprocessing/bin/BPE_SWJ/vocab.json +1 -1
  31. chebai-0.0.2.dev0/chebai/preprocessing/bin/deepsmiles_token/tokens.txt +736 -0
  32. chebai-0.0.2.dev0/chebai/preprocessing/bin/graph/tokens.txt +376 -0
  33. chebai-0.0.2.dev0/chebai/preprocessing/bin/graph_properties/tokens.txt +0 -0
  34. chebai-0.0.1.dev0/chebai/preprocessing/bin/selfies.txt → chebai-0.0.2.dev0/chebai/preprocessing/bin/selfies/tokens.txt +735 -897
  35. {chebai-0.0.1.dev0/chebai/preprocessing/bin → chebai-0.0.2.dev0/chebai/preprocessing/bin/smiles_token}/tokens.txt +776 -850
  36. chebai-0.0.2.dev0/chebai/preprocessing/bin/smiles_token_unlabeled/tokens.txt +413 -0
  37. chebai-0.0.2.dev0/chebai/preprocessing/collate.py +137 -0
  38. chebai-0.0.2.dev0/chebai/preprocessing/datasets/__init__.py +4 -0
  39. chebai-0.0.2.dev0/chebai/preprocessing/datasets/base.py +1228 -0
  40. chebai-0.0.2.dev0/chebai/preprocessing/datasets/chebi.py +1476 -0
  41. chebai-0.0.2.dev0/chebai/preprocessing/datasets/pubchem.py +1009 -0
  42. chebai-0.0.2.dev0/chebai/preprocessing/datasets/tox21.py +344 -0
  43. chebai-0.0.2.dev0/chebai/preprocessing/migration/__init__.py +0 -0
  44. chebai-0.0.2.dev0/chebai/preprocessing/migration/chebi_data_migration.py +338 -0
  45. chebai-0.0.2.dev0/chebai/preprocessing/reader.py +332 -0
  46. chebai-0.0.2.dev0/chebai/preprocessing/structures.py +141 -0
  47. chebai-0.0.2.dev0/chebai/result/__init__.py +0 -0
  48. chebai-0.0.2.dev0/chebai/result/analyse_sem.py +721 -0
  49. {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/result/base.py +9 -7
  50. chebai-0.0.2.dev0/chebai/result/classification.py +105 -0
  51. chebai-0.0.2.dev0/chebai/result/evaluate_predictions.py +108 -0
  52. {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/result/molplot.py +7 -10
  53. chebai-0.0.2.dev0/chebai/result/pretraining.py +65 -0
  54. chebai-0.0.2.dev0/chebai/result/utils.py +235 -0
  55. {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/train.py +151 -17
  56. chebai-0.0.2.dev0/chebai/trainer/CustomTrainer.py +149 -0
  57. chebai-0.0.2.dev0/chebai/trainer/__init__.py +0 -0
  58. chebai-0.0.2.dev0/chebai.egg-info/PKG-INFO +52 -0
  59. chebai-0.0.2.dev0/chebai.egg-info/SOURCES.txt +97 -0
  60. {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai.egg-info/requires.txt +13 -10
  61. chebai-0.0.2.dev0/chebai.egg-info/top_level.txt +2 -0
  62. chebai-0.0.2.dev0/setup.py +55 -0
  63. chebai-0.0.2.dev0/tests/__init__.py +0 -0
  64. chebai-0.0.2.dev0/tests/integration/__init__.py +3 -0
  65. chebai-0.0.2.dev0/tests/integration/testChebiData.py +164 -0
  66. chebai-0.0.2.dev0/tests/integration/testChebiDynamicDataSplits.py +197 -0
  67. chebai-0.0.2.dev0/tests/integration/testCustomBalancedAccuracyMetric.py +166 -0
  68. chebai-0.0.2.dev0/tests/integration/testCustomMacroF1Metric.py +224 -0
  69. chebai-0.0.2.dev0/tests/integration/testPubChemData.py +164 -0
  70. chebai-0.0.2.dev0/tests/integration/testTox21MolNetData.py +164 -0
  71. chebai-0.0.2.dev0/tests/unit/__init__.py +4 -0
  72. chebai-0.0.2.dev0/tests/unit/collators/__init__.py +0 -0
  73. chebai-0.0.2.dev0/tests/unit/collators/testDefaultCollator.py +65 -0
  74. chebai-0.0.2.dev0/tests/unit/collators/testRaggedCollator.py +204 -0
  75. chebai-0.0.2.dev0/tests/unit/dataset_classes/__init__.py +0 -0
  76. chebai-0.0.2.dev0/tests/unit/dataset_classes/testChEBIOverX.py +125 -0
  77. chebai-0.0.2.dev0/tests/unit/dataset_classes/testChebiDataExtractor.py +228 -0
  78. chebai-0.0.2.dev0/tests/unit/dataset_classes/testChebiOverXPartial.py +175 -0
  79. chebai-0.0.2.dev0/tests/unit/dataset_classes/testChebiTermCallback.py +69 -0
  80. chebai-0.0.2.dev0/tests/unit/dataset_classes/testDynamicDataset.py +372 -0
  81. chebai-0.0.2.dev0/tests/unit/dataset_classes/testTox21Challenge.py +128 -0
  82. chebai-0.0.2.dev0/tests/unit/dataset_classes/testXYBaseDataModule.py +92 -0
  83. chebai-0.0.2.dev0/tests/unit/mock_data/__init__.py +0 -0
  84. chebai-0.0.2.dev0/tests/unit/mock_data/ontology_mock_data.py +406 -0
  85. chebai-0.0.2.dev0/tests/unit/mock_data/tox_mock_data.py +510 -0
  86. chebai-0.0.2.dev0/tests/unit/readers/__init__.py +0 -0
  87. chebai-0.0.2.dev0/tests/unit/readers/testChemDataReader.py +107 -0
  88. chebai-0.0.2.dev0/tests/unit/readers/testDataReader.py +56 -0
  89. chebai-0.0.2.dev0/tests/unit/readers/testDeepChemDataReader.py +115 -0
  90. chebai-0.0.2.dev0/tests/unit/readers/testSelfiesReader.py +127 -0
  91. chebai-0.0.1.dev0/PKG-INFO +0 -8
  92. chebai-0.0.1.dev0/README.md +0 -29
  93. chebai-0.0.1.dev0/chebai/__init__.py +0 -3
  94. chebai-0.0.1.dev0/chebai/__main__.py +0 -4
  95. chebai-0.0.1.dev0/chebai/cli.py +0 -86
  96. chebai-0.0.1.dev0/chebai/evaluate.py +0 -35
  97. chebai-0.0.1.dev0/chebai/experiments.py +0 -355
  98. chebai-0.0.1.dev0/chebai/models/base.py +0 -226
  99. chebai-0.0.1.dev0/chebai/models/electra.py +0 -386
  100. chebai-0.0.1.dev0/chebai/models/graph.py +0 -93
  101. chebai-0.0.1.dev0/chebai/models/graph_k2.py +0 -59
  102. chebai-0.0.1.dev0/chebai/models/graphyk.py +0 -86
  103. chebai-0.0.1.dev0/chebai/preprocessing/collate.py +0 -61
  104. chebai-0.0.1.dev0/chebai/preprocessing/reader.py +0 -278
  105. chebai-0.0.1.dev0/chebai/preprocessing/structures.py +0 -97
  106. chebai-0.0.1.dev0/chebai.egg-info/PKG-INFO +0 -8
  107. chebai-0.0.1.dev0/chebai.egg-info/SOURCES.txt +0 -38
  108. chebai-0.0.1.dev0/setup.py +0 -52
  109. {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/LICENSE +0 -0
  110. {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/models/strontex.py +3 -3
  111. {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/preprocessing/bin/BPE_SWJ/merges.txt +0 -0
  112. {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/preprocessing/collect_all.py +3 -3
  113. {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai/result/prediction_json.py +0 -0
  114. {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/chebai.egg-info/dependency_links.txt +0 -0
  115. /chebai-0.0.1.dev0/chebai.egg-info/top_level.txt → /chebai-0.0.2.dev0/chebai.egg-info/not-zip-safe +0 -0
  116. {chebai-0.0.1.dev0 → chebai-0.0.2.dev0}/setup.cfg +0 -0
@@ -0,0 +1,52 @@
1
+ Metadata-Version: 2.4
2
+ Name: chebai
3
+ Version: 0.0.2.dev0
4
+ Home-page:
5
+ Author: MGlauer
6
+ Author-email: martin.glauer@ovgu.de
7
+ Requires-Python: >=3.9, <3.13
8
+ License-File: LICENSE
9
+ Requires-Dist: certifi
10
+ Requires-Dist: idna
11
+ Requires-Dist: joblib
12
+ Requires-Dist: networkx
13
+ Requires-Dist: numpy<2
14
+ Requires-Dist: pandas
15
+ Requires-Dist: python-dateutil
16
+ Requires-Dist: pytz
17
+ Requires-Dist: requests
18
+ Requires-Dist: scikit-learn
19
+ Requires-Dist: scipy
20
+ Requires-Dist: six
21
+ Requires-Dist: threadpoolctl
22
+ Requires-Dist: torch
23
+ Requires-Dist: typing-extensions
24
+ Requires-Dist: urllib3
25
+ Requires-Dist: transformers
26
+ Requires-Dist: fastobo
27
+ Requires-Dist: pysmiles==1.1.2
28
+ Requires-Dist: scikit-network
29
+ Requires-Dist: svgutils
30
+ Requires-Dist: matplotlib
31
+ Requires-Dist: rdkit
32
+ Requires-Dist: selfies
33
+ Requires-Dist: lightning>=2.5
34
+ Requires-Dist: jsonargparse[signatures]>=4.17
35
+ Requires-Dist: omegaconf
36
+ Requires-Dist: seaborn
37
+ Requires-Dist: deepsmiles
38
+ Requires-Dist: iterative-stratification
39
+ Requires-Dist: wandb
40
+ Requires-Dist: chardet
41
+ Requires-Dist: pyyaml
42
+ Requires-Dist: torchmetrics
43
+ Provides-Extra: dev
44
+ Requires-Dist: black; extra == "dev"
45
+ Requires-Dist: isort; extra == "dev"
46
+ Requires-Dist: pre-commit; extra == "dev"
47
+ Dynamic: author
48
+ Dynamic: author-email
49
+ Dynamic: license-file
50
+ Dynamic: provides-extra
51
+ Dynamic: requires-dist
52
+ Dynamic: requires-python
@@ -0,0 +1,89 @@
1
+ # ChEBai
2
+
3
+ ChEBai is a deep learning library designed for the integration of deep learning methods with chemical ontologies, particularly ChEBI.
4
+ The library emphasizes the incorporation of the semantic qualities of the ontology into the learning process.
5
+
6
+ ## Installation
7
+
8
+ To install ChEBai, follow these steps:
9
+
10
+ 1. Clone the repository:
11
+ ```
12
+ git clone https://github.com/ChEB-AI/python-chebai.git
13
+ ```
14
+
15
+ 2. Install the package:
16
+
17
+ ```
18
+ cd python-chebai
19
+ pip install .
20
+ ```
21
+
22
+ ## Usage
23
+
24
+ Training and inference are abstracted using PyTorch Lightning modules.
25
+ Here are some CLI commands for the standard functionalities of pretraining, ontology extension, fine-tuning for toxicity and prediction.
26
+ For further details, see the [wiki](https://github.com/ChEB-AI/python-chebai/wiki).
27
+ If you face any problems, please open a new [issue](https://github.com/ChEB-AI/python-chebai/issues/new).
28
+
29
+ ### Pretraining
30
+ ```
31
+ python -m chebai fit --data.class_path=chebai.preprocessing.datasets.pubchem.PubchemChem --model=configs/model/electra-for-pretraining.yml --trainer=configs/training/pretraining_trainer.yml
32
+ ```
33
+
34
+ ### Structure-based ontology extension
35
+ ```
36
+ python -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.pretrained_checkpoint=[path-to-pretrained-model] --model.load_prefix=generator. --data=[path-to-dataset-config] --model.out_dim=[number-of-labels]
37
+ ```
38
+ A command with additional options may look like this:
39
+ ```
40
+ python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi50.yml --model.criterion=configs/loss/bce.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_unweighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000
41
+ ```
42
+
43
+ ### Fine-tuning for Toxicity prediction
44
+ ```
45
+ python -m chebai fit --config=[path-to-your-tox21-config] --trainer.callbacks=configs/training/default_callbacks.yml --model.pretrained_checkpoint=[path-to-pretrained-model]
46
+ ```
47
+
48
+ ### Predicting classes given SMILES strings
49
+ ```
50
+ python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path=[path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
51
+ ```
52
+ The input files should contain a list of line-separated SMILES strings. This generates a CSV file that contains
53
+ one row for each SMILES string and one column for each class.
54
+ The `classes_path` is the path to the dataset's `raw/classes.txt` file that contains the relationship between model output and ChEBI-IDs.
55
+
56
+ ## Evaluation
57
+
58
+ An example for evaluating a model trained on the ontology extension task is given in `tutorials/eval_model_basic.ipynb`.
59
+ It takes in the finetuned model as input for performing the evaluation.
60
+
61
+ ## Cross-validation
62
+ You can do inner k-fold cross-validation, i.e., train models on k train-validation splits that all use the same test
63
+ set. For that, you need to specify the total number of folds as
64
+ ```
65
+ --data.init_args.inner_k_folds=K
66
+ ```
67
+ and the fold to be used in the current optimisation run as
68
+ ```
69
+ --data.init_args.fold_index=I
70
+ ```
71
+ To train K models, you need to do K such calls, each with a different `fold_index`. On the first call with a given
72
+ `inner_k_folds`, all folds will be created and stored in the data directory.
73
+
74
+ ## Note for developers
75
+
76
+ If you have used ChEBai before PR #39, the file structure in which your ChEBI-data is saved has changed. This means that
77
+ datasets will be freshly generated. The data however is the same. If you want to keep the old data (including the old
78
+ splits), you can use a migration script. It copies the old data to the new location for a specific ChEBI class
79
+ (including chebi version and other parameters). The script can be called by specifying the data module from a config
80
+ ```
81
+ python chebai/preprocessing/migration/chebi_data_migration.py migrate --datamodule=[path-to-data-config]
82
+ ```
83
+ or by specifying the class name (e.g. `ChEBIOver50`) and arguments separately
84
+ ```
85
+ python chebai/preprocessing/migration/chebi_data_migration.py migrate --class_name=[data-class] [--chebi_version=[version]]
86
+ ```
87
+ The new dataset will by default generate random data splits (with a given seed).
88
+ To reuse a fixed data split, you have to provide the path of the csv file generated during the migration:
89
+ `--data.init_args.splits_file_path=[path-to-processed_data]/splits.csv`
@@ -0,0 +1,30 @@
1
+ import os
2
+ from typing import Any
3
+
4
+ import torch
5
+
6
+ # Get the absolute path of the current file's directory
7
+ MODULE_PATH = os.path.abspath(os.path.dirname(__file__))
8
+
9
+
10
class CustomTensor(torch.Tensor):
    """
    Subclass of `torch.Tensor` whose constructor delegates to `torch.tensor`.

    Calling ``CustomTensor(data)`` builds a tensor from ``data`` and hands back
    exactly the object produced by `torch.tensor`.
    """

    def __new__(cls, data: Any) -> "CustomTensor":
        """
        Build a tensor from the given data.

        Args:
            data (Any): The data to be converted into a tensor.

        Returns:
            CustomTensor: A tensor containing the provided data, as constructed
                by `torch.tensor`.
        """
        tensor = torch.tensor(data)
        return tensor
@@ -0,0 +1,10 @@
1
+ from chebai.cli import cli
2
+
3
# Entry point for the CLI application: when this module is executed as the
# main program, run the `cli` function imported from `chebai.cli`.
if __name__ == "__main__":
    cli()
File without changes
@@ -0,0 +1,180 @@
1
+ import torch
2
+ import torchmetrics
3
+
4
+
5
def custom_reduce_fx(input: torch.Tensor) -> torch.Tensor:
    """
    Custom reduction function for distributed training.

    Prints a short message naming the device of ``input`` and then sums the
    tensor over its first dimension.

    Args:
        input (torch.Tensor): The input tensor to be reduced.

    Returns:
        torch.Tensor: The input summed along dimension 0.
    """
    device = input.device
    print(f"called reduce (device: {device})")
    return input.sum(dim=0)
17
+
18
+
19
class MacroF1(torchmetrics.Metric):
    """
    Computes the Macro F1 score, which is the unweighted mean of F1 scores for each class.
    This implementation differs from torchmetrics.classification.MultilabelF1Score in the behaviour for undefined
    values (i.e., classes where TP+FN=0). The torchmetrics implementation sets these classes to a default value.
    Here, the mean is only taken over classes which have at least one positive sample.

    Args:
        num_labels (int): Number of classes/labels.
        dist_sync_on_step (bool, optional): Synchronize metric state across processes at each forward
            before returning the value at the step. Default: False.
        threshold (float, optional): Threshold for converting predicted probabilities to binary (0, 1) predictions.
            Default: 0.5.
    """

    def __init__(
        self, num_labels: int, dist_sync_on_step: bool = False, threshold: float = 0.5
    ):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        # One integer counter per label for each statistic; all of them are
        # summed across processes.
        for state_name in (
            "true_positives",
            "positive_predictions",
            "positive_labels",
        ):
            self.add_state(
                state_name,
                default=torch.zeros(num_labels, dtype=torch.int),
                dist_reduce_fx="sum",
            )
        self.threshold = threshold

    def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
        """
        Update the state (TPs, Positive Predictions, Positive labels) with the current batch of predictions and labels.

        Any non-zero label counts as positive. The same boolean view of the
        labels is used for the true positives and for the positive-label count,
        so float, integer and boolean label tensors are all accepted.

        Args:
            preds (torch.Tensor): Predictions from the model.
            labels (torch.Tensor): Ground truth labels.
        """
        predicted_positive = preds > self.threshold
        actual_positive = labels.to(torch.bool)

        self.true_positives += torch.sum(
            torch.logical_and(predicted_positive, actual_positive), dim=0
        )
        self.positive_predictions += torch.sum(predicted_positive, dim=0)
        # Summing the boolean mask (instead of the raw labels) keeps the result
        # an integer count, which can be added in place to the integer state.
        self.positive_labels += torch.sum(actual_positive, dim=0)

    def compute(self) -> torch.Tensor:
        """
        Compute the Macro F1 score.

        Returns:
            torch.Tensor: The computed Macro F1 score, i.e. the mean of the class-wise F1 scores over all classes
                that have at least one positive label.
        """
        # ignore classes without positive labels
        # classes with positive labels, but no positive predictions will get a precision of "nan" (0 divided by 0),
        # which is propagated to the classwise_f1 and then turned into 0
        mask = self.positive_labels != 0
        precision = self.true_positives[mask] / self.positive_predictions[mask]
        recall = self.true_positives[mask] / self.positive_labels[mask]
        classwise_f1 = 2 * precision * recall / (precision + recall)
        # if (precision and recall are 0) or (precision is nan), set f1 to 0
        classwise_f1 = classwise_f1.nan_to_num()
        return torch.mean(classwise_f1)
89
+
90
+
91
class BalancedAccuracy(torchmetrics.Metric):
    """
    Balanced Accuracy metric: the mean of the true positive rate (TPR) and the true negative rate (TNR),
    averaged over all classes. Useful for imbalanced datasets.
    Balanced Accuracy = (TPR + TNR)/2 = (TP/(TP + FN) + (TN)/(TN + FP))/2

    Args:
        num_labels (int): Number of classes/labels.
        dist_sync_on_step (bool, optional): Synchronize metric state across processes at each forward
            before returning the value at the step. Default: False.
        threshold (float, optional): Threshold for converting predicted probabilities to binary (0, 1) predictions.
            Default: 0.5.
    """

    def __init__(
        self, num_labels: int, dist_sync_on_step: bool = False, threshold: float = 0.5
    ):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        # Register one integer counter per label for each confusion-matrix cell.
        for state_name in (
            "true_positives",
            "false_positives",
            "true_negatives",
            "false_negatives",
        ):
            self.add_state(
                state_name,
                default=torch.zeros(num_labels, dtype=torch.int),
                dist_reduce_fx="sum",
            )

        self.threshold = threshold

    def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
        """
        Accumulate the per-class TPs, FPs, TNs and FNs of the current batch into the metric state.

        Args:
            preds (torch.Tensor): Predictions from the model.
            labels (torch.Tensor): Ground truth labels.
        """
        # Size: Batch_size x Num_of_Classes
        predicted_pos = preds > self.threshold
        predicted_neg = preds <= self.threshold
        actual_pos = labels.to(torch.bool)
        actual_neg = ~actual_pos

        # Summing over the 1st dimension (dim=0) yields per-class counts
        # (Size: Num_of_Classes).
        self.true_positives += torch.sum(
            torch.logical_and(predicted_pos, actual_pos), dim=0
        )
        self.false_positives += torch.sum(
            torch.logical_and(predicted_pos, actual_neg), dim=0
        )
        self.true_negatives += torch.sum(
            torch.logical_and(predicted_neg, actual_neg), dim=0
        )
        self.false_negatives += torch.sum(
            torch.logical_and(predicted_neg, actual_pos), dim=0
        )

    def compute(self) -> torch.Tensor:
        """
        Compute the Balanced Accuracy from the accumulated counts.

        Returns:
            torch.Tensor: The computed Balanced Accuracy.
        """
        tp, fn = self.true_positives, self.false_negatives
        tn, fp = self.true_negatives, self.false_positives

        # nan rates (0 divided by 0) are converted to 0
        tpr = (tp / (tp + fn)).nan_to_num()
        tnr = (tn / (tn + fp)).nan_to_num()

        return torch.mean((tpr + tnr) / 2)
@@ -0,0 +1,95 @@
1
+ import os
2
+
3
+ from lightning.fabric.utilities.cloud_io import _is_dir
4
+ from lightning.fabric.utilities.types import _PATH
5
+ from lightning.pytorch import LightningModule, Trainer
6
+ from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
7
+ from lightning.pytorch.loggers import WandbLogger
8
+ from lightning.pytorch.utilities.rank_zero import rank_zero_info
9
+ from lightning_utilities.core.rank_zero import rank_zero_warn
10
+
11
+
12
class CustomModelCheckpoint(ModelCheckpoint):
    """
    Checkpoint callback that resolves the checkpoint path itself, so that checkpoints end up in the same
    directory as the other logs when using CustomLogger.
    Inherits from PyTorch Lightning's ModelCheckpoint class.
    """

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """
        Determine the directory in which checkpoints are saved.

        Any directory path already stored on the callback is cleared first, so the path is always resolved
        anew via the custom resolution logic (based on the trainer's logger).

        Note:
            Same as in parent class, duplicated to be able to call self.__resolve_ckpt_dir

        Args:
            trainer (Trainer): The Trainer instance.
            pl_module (LightningModule): The LightningModule instance.
            stage (str): The stage of training (e.g., 'fit').
        """
        # Drop any previously stored path so the resolution below always runs.
        self.dirpath = None
        resolved = self.__resolve_ckpt_dir(trainer)
        self.dirpath = trainer.strategy.broadcast(resolved)
        if trainer.is_global_zero and stage == "fit":
            self.__warn_if_dir_not_empty(self.dirpath)

    def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
        """
        Emit a warning when the checkpoint directory already contains entries.

        Note:
            Same as in parent class, duplicated because method in parent class is not accessible

        Args:
            dirpath (_PATH): The path to the checkpoint directory.
        """
        if self.save_top_k == 0:
            return
        if not _is_dir(self._fs, dirpath, strict=True):
            return
        if len(self._fs.ls(dirpath)) > 0:
            rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

    def __resolve_ckpt_dir(self, trainer: Trainer) -> _PATH:
        """
        Work out the checkpoint directory, keeping checkpoints next to the Wandb logs when a WandbLogger is used.

        Note:
            Overwritten for compatibility with wandb -> saves checkpoints in same dir as wandb logs

        Args:
            trainer (Trainer): The Trainer instance.

        Returns:
            _PATH: The resolved checkpoint directory path.
        """
        rank_zero_info("Resolving checkpoint dir (custom)")
        if self.dirpath is not None:
            # short circuit if dirpath was passed to ModelCheckpoint
            return self.dirpath

        if len(trainer.loggers) > 0:
            # Only the first logger is consulted.
            logger = trainer.loggers[0]
            save_dir = (
                logger.save_dir
                if logger.save_dir is not None
                else trainer.default_root_dir
            )
            name = logger.name
            version = logger.version
            if not isinstance(version, str):
                version = f"version_{version}"

            uses_wandb_dir = isinstance(logger, WandbLogger) and isinstance(
                logger.experiment.dir, str
            )
            if uses_wandb_dir:
                ckpt_path = os.path.join(logger.experiment.dir, "checkpoints")
            else:
                ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints")
        else:
            # if no loggers, use default_root_dir
            ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")

        rank_zero_info(f"Now using checkpoint path {ckpt_path}")
        return ckpt_path
@@ -0,0 +1,55 @@
1
+ import os
2
+ import pickle
3
+ from typing import Any, Literal, Sequence
4
+
5
+ import torch
6
+ from lightning.pytorch import LightningModule, Trainer
7
+ from lightning.pytorch.callbacks import BasePredictionWriter
8
+
9
+
10
class PredictionWriter(BasePredictionWriter):
    """
    Custom callback for writing predictions to a file at the end of each epoch.

    Only `write_on_epoch_end` is implemented by this class.

    Args:
        output_dir (str): The directory where prediction files will be saved.
        write_interval (str): When to write predictions. Options are "batch", "epoch" or "batch_and_epoch".
    """

    def __init__(
        self,
        output_dir: str,
        write_interval: Literal["batch", "epoch", "batch_and_epoch"],
    ):
        super().__init__(write_interval)
        self.output_dir = output_dir
        self.prediction_file_name = "predictions.pkl"

    def write_on_epoch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        predictions: Sequence[Any],
        batch_indices: Sequence[Any],
    ) -> None:
        """
        Writes the predictions to a pickle file at the end of the epoch.

        For every entry of ``predictions`` one dictionary with the keys ``ident``, ``predictions`` (sigmoid of
        the logits) and ``labels`` is stored. The output directory is created if it does not exist yet.

        Args:
            trainer (Trainer): The Trainer instance.
            pl_module (LightningModule): The LightningModule instance.
            predictions (Sequence[Any]): Any sequence of predictions for the epoch.
            batch_indices (Sequence[Any]): Any sequence of batch indices.
        """
        # NOTE(review): only the first ident / label row of each entry is kept, while the
        # predictions of the whole entry are stored - presumably batches of size 1; confirm.
        results = [
            dict(
                ident=row["data"]["idents"][0],
                # detach + move to CPU first: `.numpy()` is not available for tensors
                # that live on an accelerator or that track gradients
                predictions=torch.sigmoid(row["output"]["logits"])
                .detach()
                .cpu()
                .numpy(),
                labels=(
                    row["labels"][0].detach().cpu().numpy()
                    if row["labels"] is not None
                    else None
                ),
            )
            for row in predictions
        ]
        # An empty output_dir means "current working directory"; nothing to create then.
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)
        with open(
            os.path.join(self.output_dir, self.prediction_file_name), "wb"
        ) as fout:
            pickle.dump(results, fout)
@@ -0,0 +1,86 @@
1
+ import json
2
+ import os
3
+ from typing import Any, Dict, List, Literal, Union
4
+
5
+ import torch
6
+ from lightning.pytorch.callbacks import BasePredictionWriter
7
+
8
+
9
class ChebaiPredictionWriter(BasePredictionWriter):
    """
    A custom prediction writer for saving batch and epoch predictions during model training.

    This class inherits from `BasePredictionWriter` and is designed to save predictions
    in a specified output directory at specified intervals.

    Args:
        output_dir (str): The directory where predictions will be saved.
        write_interval (str): The interval at which predictions will be written
            ("batch", "epoch" or "batch_and_epoch").
        target_file (str): The name of the file where epoch predictions will be saved (default: "predictions.json").
    """

    def __init__(
        self,
        output_dir: str,
        write_interval: Literal["batch", "epoch", "batch_and_epoch"],
        target_file: str = "predictions.json",
    ) -> None:
        super().__init__(write_interval)
        self.output_dir = output_dir
        self.target_file = target_file

    def write_on_batch_end(
        self,
        trainer: Any,
        pl_module: Any,
        prediction: Union[torch.Tensor, List[torch.Tensor]],
        batch_indices: List[int],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """
        Saves batch predictions at the end of each batch.

        The prediction is stored with `torch.save` as
        ``<output_dir>/<dataloader_idx>/<batch_idx>.pt``; missing directories are created.

        Args:
            trainer (Any): The trainer instance.
            pl_module (Any): The LightningModule instance.
            prediction (Union[torch.Tensor, List[torch.Tensor]]): The prediction output from the model.
            batch_indices (List[int]): The indices of the batch.
            batch (Any): The current batch.
            batch_idx (int): The index of the batch.
            dataloader_idx (int): The index of the dataloader.
        """
        outpath = os.path.join(self.output_dir, str(dataloader_idx), f"{batch_idx}.pt")
        os.makedirs(os.path.dirname(outpath), exist_ok=True)
        torch.save(prediction, outpath)

    def write_on_epoch_end(
        self,
        trainer: Any,
        pl_module: Any,
        predictions: List[Dict[str, Any]],
        batch_indices: List[int],
    ) -> None:
        """
        Saves all predictions at the end of each epoch in a JSON file.

        The file contains one object per sample with the keys ``ident``, ``labels`` and ``predictions``
        (sigmoid of the logits). The output directory is created if it does not exist yet.

        Args:
            trainer (Any): The trainer instance.
            pl_module (Any): The LightningModule instance.
            predictions (List[Dict[str, Any]]): The list of prediction outputs from the model.
            batch_indices (List[int]): The indices of the batches.
        """
        pred_list = []
        for p in predictions:
            idents = p["data"]["idents"]
            labels = p["data"]["labels"]
            if labels is not None:
                labels = labels.tolist()
            else:
                # no labels available: pair every ident with None
                labels = [None for _ in idents]
            output = torch.sigmoid(p["output"]["logits"]).tolist()
            for ident, label, pred in zip(idents, labels, output):
                pred_list.append(dict(ident=ident, labels=label, predictions=pred))
        # Mirror write_on_batch_end: make sure the target directory exists before writing.
        # An empty output_dir means "current working directory"; nothing to create then.
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)
        with open(os.path.join(self.output_dir, self.target_file), "wt") as fout:
            json.dump(pred_list, fout)