chebai 0.0.2.dev0__tar.gz → 1.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {chebai-0.0.2.dev0 → chebai-1.0.1}/PKG-INFO +2 -2
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/models/base.py +5 -4
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/datasets/base.py +49 -34
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/datasets/chebi.py +2 -2
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/reader.py +54 -27
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/result/analyse_sem.py +66 -40
- chebai-1.0.1/chebai/result/generate_class_properties.py +200 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/result/utils.py +83 -5
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai.egg-info/PKG-INFO +2 -2
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai.egg-info/SOURCES.txt +1 -1
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai.egg-info/requires.txt +1 -1
- {chebai-0.0.2.dev0 → chebai-1.0.1}/setup.py +2 -2
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/readers/testChemDataReader.py +62 -9
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/readers/testDeepChemDataReader.py +9 -7
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/readers/testSelfiesReader.py +8 -6
- chebai-0.0.2.dev0/chebai/preprocessing/bin/graph_properties/tokens.txt +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/LICENSE +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/README.md +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/__init__.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/__main__.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/callbacks/__init__.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/callbacks/epoch_metrics.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/callbacks/model_checkpoint.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/callbacks/prediction_callback.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/callbacks.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/cli.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/loggers/__init__.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/loggers/custom.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/loss/__init__.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/loss/bce_weighted.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/loss/mixed.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/loss/pretraining.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/loss/semantic.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/models/__init__.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/models/chemberta.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/models/chemyk.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/models/electra.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/models/external/__init__.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/models/ffn.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/models/lnn_model.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/models/lstm.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/models/recursive.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/models/strontex.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/molecule.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/__init__.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/bin/BPE_SWJ/merges.txt +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/bin/BPE_SWJ/vocab.json +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/bin/deepsmiles_token/tokens.txt +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/bin/graph/tokens.txt +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/bin/selfies/tokens.txt +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/bin/smiles_token/tokens.txt +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/bin/smiles_token_unlabeled/tokens.txt +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/collate.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/collect_all.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/datasets/__init__.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/datasets/pubchem.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/datasets/tox21.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/migration/__init__.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/migration/chebi_data_migration.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/structures.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/result/__init__.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/result/base.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/result/classification.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/result/evaluate_predictions.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/result/molplot.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/result/prediction_json.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/result/pretraining.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/train.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/trainer/CustomTrainer.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/trainer/__init__.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai.egg-info/dependency_links.txt +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai.egg-info/not-zip-safe +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai.egg-info/top_level.txt +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/setup.cfg +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/__init__.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/integration/__init__.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/integration/testChebiData.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/integration/testChebiDynamicDataSplits.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/integration/testCustomBalancedAccuracyMetric.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/integration/testCustomMacroF1Metric.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/integration/testPubChemData.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/integration/testTox21MolNetData.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/__init__.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/collators/__init__.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/collators/testDefaultCollator.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/collators/testRaggedCollator.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/dataset_classes/__init__.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/dataset_classes/testChEBIOverX.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/dataset_classes/testChebiDataExtractor.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/dataset_classes/testChebiOverXPartial.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/dataset_classes/testChebiTermCallback.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/dataset_classes/testDynamicDataset.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/dataset_classes/testTox21Challenge.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/dataset_classes/testXYBaseDataModule.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/mock_data/__init__.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/mock_data/ontology_mock_data.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/mock_data/tox_mock_data.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/readers/__init__.py +0 -0
- {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/readers/testDataReader.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: chebai
|
|
3
|
-
Version: 0.0.2.dev0
|
|
3
|
+
Version: 1.0.1
|
|
4
4
|
Home-page:
|
|
5
5
|
Author: MGlauer
|
|
6
6
|
Author-email: martin.glauer@ovgu.de
|
|
@@ -10,7 +10,7 @@ Requires-Dist: certifi
|
|
|
10
10
|
Requires-Dist: idna
|
|
11
11
|
Requires-Dist: joblib
|
|
12
12
|
Requires-Dist: networkx
|
|
13
|
-
Requires-Dist: numpy
|
|
13
|
+
Requires-Dist: numpy
|
|
14
14
|
Requires-Dist: pandas
|
|
15
15
|
Requires-Dist: python-dateutil
|
|
16
16
|
Requires-Dist: pytz
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
import logging
|
|
2
|
-
from
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from typing import Any, Dict, Iterable, Optional, Union
|
|
3
4
|
|
|
4
5
|
import torch
|
|
5
6
|
from lightning.pytorch.core.module import LightningModule
|
|
6
|
-
from torchmetrics import Metric
|
|
7
7
|
|
|
8
8
|
from chebai.preprocessing.structures import XYData
|
|
9
9
|
|
|
@@ -12,7 +12,7 @@ logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
|
|
|
12
12
|
_MODEL_REGISTRY = dict()
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
class ChebaiBaseNet(LightningModule):
|
|
15
|
+
class ChebaiBaseNet(LightningModule, ABC):
|
|
16
16
|
"""
|
|
17
17
|
Base class for Chebai neural network models inheriting from PyTorch Lightning's LightningModule.
|
|
18
18
|
|
|
@@ -353,6 +353,7 @@ class ChebaiBaseNet(LightningModule):
|
|
|
353
353
|
logger=True,
|
|
354
354
|
)
|
|
355
355
|
|
|
356
|
+
@abstractmethod
|
|
356
357
|
def forward(self, x: Dict[str, Any]) -> torch.Tensor:
|
|
357
358
|
"""
|
|
358
359
|
Defines the forward pass.
|
|
@@ -363,7 +364,7 @@ class ChebaiBaseNet(LightningModule):
|
|
|
363
364
|
Returns:
|
|
364
365
|
torch.Tensor: The model output.
|
|
365
366
|
"""
|
|
366
|
-
|
|
367
|
+
pass
|
|
367
368
|
|
|
368
369
|
def configure_optimizers(self, **kwargs) -> torch.optim.Optimizer:
|
|
369
370
|
"""
|
|
@@ -29,7 +29,8 @@ class XYBaseDataModule(LightningDataModule):
|
|
|
29
29
|
|
|
30
30
|
Args:
|
|
31
31
|
batch_size (int): The batch size for data loading. Default is 1.
|
|
32
|
-
|
|
32
|
+
test_split (float): The ratio of test data to total data. Default is 0.1.
|
|
33
|
+
validation_split (float): The ratio of validation data to total data. Default is 0.05.
|
|
33
34
|
reader_kwargs (dict): Additional keyword arguments to be passed to the data reader. Default is None.
|
|
34
35
|
prediction_kind (str): The kind of prediction to be performed (only relevant for the predict_dataloader). Default is "test".
|
|
35
36
|
data_limit (Optional[int]): The maximum number of data samples to load. If set to None, the complete dataset will be used. Default is None.
|
|
@@ -45,7 +46,8 @@ class XYBaseDataModule(LightningDataModule):
|
|
|
45
46
|
Attributes:
|
|
46
47
|
READER (DataReader): The data reader class to use.
|
|
47
48
|
reader (DataReader): An instance of the data reader class.
|
|
48
|
-
|
|
49
|
+
test_split (float): The ratio of test data to total data.
|
|
50
|
+
validation_split (float): The ratio of validation data to total data.
|
|
49
51
|
batch_size (int): The batch size for data loading.
|
|
50
52
|
prediction_kind (str): The kind of prediction to be performed.
|
|
51
53
|
data_limit (Optional[int]): The maximum number of data samples to load.
|
|
@@ -68,7 +70,8 @@ class XYBaseDataModule(LightningDataModule):
|
|
|
68
70
|
def __init__(
|
|
69
71
|
self,
|
|
70
72
|
batch_size: int = 1,
|
|
71
|
-
|
|
73
|
+
test_split: Optional[float] = 0.1,
|
|
74
|
+
validation_split: Optional[float] = 0.05,
|
|
72
75
|
reader_kwargs: Optional[dict] = None,
|
|
73
76
|
prediction_kind: str = "test",
|
|
74
77
|
data_limit: Optional[int] = None,
|
|
@@ -86,7 +89,9 @@ class XYBaseDataModule(LightningDataModule):
|
|
|
86
89
|
if reader_kwargs is None:
|
|
87
90
|
reader_kwargs = dict()
|
|
88
91
|
self.reader = self.READER(**reader_kwargs)
|
|
89
|
-
self.
|
|
92
|
+
self.test_split = test_split
|
|
93
|
+
self.validation_split = validation_split
|
|
94
|
+
|
|
90
95
|
self.batch_size = batch_size
|
|
91
96
|
self.prediction_kind = prediction_kind
|
|
92
97
|
self.data_limit = data_limit
|
|
@@ -335,8 +340,9 @@ class XYBaseDataModule(LightningDataModule):
|
|
|
335
340
|
val
|
|
336
341
|
for val in data
|
|
337
342
|
if val["features"] is not None
|
|
338
|
-
and
|
|
339
|
-
|
|
343
|
+
and (
|
|
344
|
+
self.n_token_limit is None or len(val["features"]) <= self.n_token_limit
|
|
345
|
+
)
|
|
340
346
|
]
|
|
341
347
|
|
|
342
348
|
return data
|
|
@@ -439,13 +445,25 @@ class XYBaseDataModule(LightningDataModule):
|
|
|
439
445
|
):
|
|
440
446
|
self.setup_processed()
|
|
441
447
|
|
|
442
|
-
|
|
443
|
-
|
|
448
|
+
self._after_setup(**kwargs)
|
|
449
|
+
|
|
450
|
+
def _after_setup(self, **kwargs):
|
|
451
|
+
"""
|
|
452
|
+
Finalize the setup process after ensuring the processed data is available.
|
|
444
453
|
|
|
454
|
+
This method performs post-setup tasks like finalizing the reader and setting internal properties.
|
|
455
|
+
"""
|
|
456
|
+
self.reader.on_finish()
|
|
445
457
|
self._set_processed_data_props()
|
|
446
458
|
|
|
447
459
|
def _set_processed_data_props(self):
|
|
460
|
+
"""
|
|
461
|
+
Load processed data and extract metadata.
|
|
448
462
|
|
|
463
|
+
Sets:
|
|
464
|
+
- self._num_of_labels: Number of target labels in the dataset.
|
|
465
|
+
- self._feature_vector_size: Maximum feature vector length across all data points.
|
|
466
|
+
"""
|
|
449
467
|
data_pt = torch.load(
|
|
450
468
|
os.path.join(self.processed_dir, self.processed_file_names_dict["data"]),
|
|
451
469
|
weights_only=False,
|
|
@@ -1009,15 +1027,13 @@ class _DynamicDataset(XYBaseDataModule, ABC):
|
|
|
1009
1027
|
|
|
1010
1028
|
labels_list = df["labels"].tolist()
|
|
1011
1029
|
|
|
1012
|
-
test_size = 1 - self.train_split - (1 - self.train_split) ** 2
|
|
1013
|
-
|
|
1014
1030
|
if len(labels_list[0]) > 1:
|
|
1015
1031
|
splitter = MultilabelStratifiedShuffleSplit(
|
|
1016
|
-
n_splits=1, test_size=
|
|
1032
|
+
n_splits=1, test_size=self.test_split, random_state=seed
|
|
1017
1033
|
)
|
|
1018
1034
|
else:
|
|
1019
1035
|
splitter = StratifiedShuffleSplit(
|
|
1020
|
-
n_splits=1, test_size=
|
|
1036
|
+
n_splits=1, test_size=self.test_split, random_state=seed
|
|
1021
1037
|
)
|
|
1022
1038
|
|
|
1023
1039
|
train_indices, test_indices = next(splitter.split(labels_list, labels_list))
|
|
@@ -1070,16 +1086,17 @@ class _DynamicDataset(XYBaseDataModule, ABC):
|
|
|
1070
1086
|
|
|
1071
1087
|
return folds
|
|
1072
1088
|
|
|
1073
|
-
# scale val set size by 1/self.train_split to compensate for (hypothetical) test set size (1-self.train_split)
|
|
1074
|
-
test_size = ((1 - self.train_split) ** 2) / self.train_split
|
|
1075
|
-
|
|
1076
1089
|
if len(labels_list_trainval[0]) > 1:
|
|
1077
1090
|
splitter = MultilabelStratifiedShuffleSplit(
|
|
1078
|
-
n_splits=1,
|
|
1091
|
+
n_splits=1,
|
|
1092
|
+
test_size=self.validation_split / (1 - self.test_split),
|
|
1093
|
+
random_state=seed,
|
|
1079
1094
|
)
|
|
1080
1095
|
else:
|
|
1081
1096
|
splitter = StratifiedShuffleSplit(
|
|
1082
|
-
n_splits=1,
|
|
1097
|
+
n_splits=1,
|
|
1098
|
+
test_size=self.validation_split / (1 - self.test_split),
|
|
1099
|
+
random_state=seed,
|
|
1083
1100
|
)
|
|
1084
1101
|
|
|
1085
1102
|
train_indices, validation_indices = next(
|
|
@@ -1102,7 +1119,9 @@ class _DynamicDataset(XYBaseDataModule, ABC):
|
|
|
1102
1119
|
splits_df = pd.read_csv(self.splits_file_path)
|
|
1103
1120
|
|
|
1104
1121
|
filename = self.processed_file_names_dict["data"]
|
|
1105
|
-
data = self.
|
|
1122
|
+
data = self.load_processed_data_from_file(
|
|
1123
|
+
os.path.join(self.processed_dir, filename)
|
|
1124
|
+
)
|
|
1106
1125
|
df_data = pd.DataFrame(data)
|
|
1107
1126
|
|
|
1108
1127
|
train_ids = splits_df[splits_df["split"] == "train"]["id"]
|
|
@@ -1113,6 +1132,7 @@ class _DynamicDataset(XYBaseDataModule, ABC):
|
|
|
1113
1132
|
self._dynamic_df_val = df_data[df_data["ident"].isin(validation_ids)]
|
|
1114
1133
|
self._dynamic_df_test = df_data[df_data["ident"].isin(test_ids)]
|
|
1115
1134
|
|
|
1135
|
+
# ------------------------------ Phase: DataLoaders -----------------------------------
|
|
1116
1136
|
def load_processed_data(
|
|
1117
1137
|
self, kind: Optional[str] = None, filename: Optional[str] = None
|
|
1118
1138
|
) -> List[Dict[str, Any]]:
|
|
@@ -1148,24 +1168,19 @@ class _DynamicDataset(XYBaseDataModule, ABC):
|
|
|
1148
1168
|
|
|
1149
1169
|
# If both kind and filename are given, use filename
|
|
1150
1170
|
if kind is not None and filename is None:
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1155
|
-
|
|
1156
|
-
|
|
1157
|
-
|
|
1158
|
-
return data_df.to_dict(orient="records")
|
|
1159
|
-
except KeyError:
|
|
1160
|
-
kind = f"{kind}"
|
|
1171
|
+
if self.use_inner_cross_validation and kind != "test":
|
|
1172
|
+
filename = self.processed_file_names_dict[
|
|
1173
|
+
f"fold_{self.fold_index}_{kind}"
|
|
1174
|
+
]
|
|
1175
|
+
else:
|
|
1176
|
+
data_df = self.dynamic_split_dfs[kind]
|
|
1177
|
+
return data_df.to_dict(orient="records")
|
|
1161
1178
|
|
|
1162
1179
|
# If filename is provided
|
|
1163
|
-
|
|
1164
|
-
|
|
1165
|
-
|
|
1166
|
-
|
|
1167
|
-
except FileNotFoundError:
|
|
1168
|
-
raise FileNotFoundError(f"File {filename} doesn't exist")
|
|
1180
|
+
return self.load_processed_data_from_file(filename)
|
|
1181
|
+
|
|
1182
|
+
def load_processed_data_from_file(self, filename):
|
|
1183
|
+
return torch.load(os.path.join(filename), weights_only=False)
|
|
1169
1184
|
|
|
1170
1185
|
# ------------------------------ Phase: Raw Properties -----------------------------------
|
|
1171
1186
|
@property
|
|
@@ -401,8 +401,8 @@ class _ChEBIDataExtractor(_DynamicDataset, ABC):
|
|
|
401
401
|
"""
|
|
402
402
|
try:
|
|
403
403
|
filename = self.processed_file_names_dict["data"]
|
|
404
|
-
data_chebi_version =
|
|
405
|
-
os.path.join(self.processed_dir, filename)
|
|
404
|
+
data_chebi_version = self.load_processed_data_from_file(
|
|
405
|
+
os.path.join(self.processed_dir, filename)
|
|
406
406
|
)
|
|
407
407
|
except FileNotFoundError:
|
|
408
408
|
raise FileNotFoundError(
|
|
@@ -1,5 +1,9 @@
|
|
|
1
|
+
import inspect
|
|
1
2
|
import os
|
|
2
|
-
|
|
3
|
+
import sys
|
|
4
|
+
from abc import ABC
|
|
5
|
+
from itertools import islice
|
|
6
|
+
from typing import Any, Dict, List, Optional
|
|
3
7
|
|
|
4
8
|
import deepsmiles
|
|
5
9
|
import selfies as sf
|
|
@@ -36,7 +40,7 @@ class DataReader:
|
|
|
36
40
|
if collator_kwargs is None:
|
|
37
41
|
collator_kwargs = dict()
|
|
38
42
|
self.collator = self.COLLATOR(**collator_kwargs)
|
|
39
|
-
self.dirname = os.path.dirname(
|
|
43
|
+
self.dirname = os.path.dirname(inspect.getfile(self.__class__))
|
|
40
44
|
self._token_path = token_path
|
|
41
45
|
|
|
42
46
|
def _get_raw_data(self, row: Dict[str, Any]) -> Any:
|
|
@@ -117,33 +121,65 @@ class DataReader:
|
|
|
117
121
|
return
|
|
118
122
|
|
|
119
123
|
|
|
120
|
-
class ChemDataReader(DataReader):
|
|
124
|
+
class TokenIndexerReader(DataReader, ABC):
|
|
121
125
|
"""
|
|
122
|
-
|
|
126
|
+
Abstract base class for reading tokenized data and mapping tokens to unique indices.
|
|
123
127
|
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
token_path: Optional path for the token file.
|
|
127
|
-
kwargs: Additional keyword arguments.
|
|
128
|
+
This class maintains a cache of token-to-index mappings that can be extended during runtime,
|
|
129
|
+
and saves new tokens to a persistent file at the end of processing.
|
|
128
130
|
"""
|
|
129
131
|
|
|
130
|
-
COLLATOR = RaggedCollator
|
|
131
|
-
|
|
132
|
-
@classmethod
|
|
133
|
-
def name(cls) -> str:
|
|
134
|
-
"""Returns the name of the data reader."""
|
|
135
|
-
return "smiles_token"
|
|
136
|
-
|
|
137
132
|
def __init__(self, *args, **kwargs):
|
|
138
133
|
super().__init__(*args, **kwargs)
|
|
139
134
|
with open(self.token_path, "r") as pk:
|
|
140
|
-
self.cache
|
|
135
|
+
self.cache: Dict[str, int] = {
|
|
136
|
+
token.strip(): idx for idx, token in enumerate(pk)
|
|
137
|
+
}
|
|
138
|
+
self._loaded_tokens_count = len(self.cache)
|
|
141
139
|
|
|
142
140
|
def _get_token_index(self, token: str) -> int:
|
|
143
141
|
"""Returns a unique number for each token, automatically adds new tokens."""
|
|
144
142
|
if not str(token) in self.cache:
|
|
145
|
-
self.cache
|
|
146
|
-
return self.cache
|
|
143
|
+
self.cache[(str(token))] = len(self.cache)
|
|
144
|
+
return self.cache[str(token)] + EMBEDDING_OFFSET
|
|
145
|
+
|
|
146
|
+
def on_finish(self) -> None:
|
|
147
|
+
"""
|
|
148
|
+
Saves the current cache of tokens to the token file.This method is called after all data processing is complete.
|
|
149
|
+
"""
|
|
150
|
+
print(f"first 10 tokens: {list(islice(self.cache, 10))}")
|
|
151
|
+
|
|
152
|
+
total_tokens = len(self.cache)
|
|
153
|
+
if total_tokens > self._loaded_tokens_count:
|
|
154
|
+
print("New tokens added to the cache, Saving them to token file.....")
|
|
155
|
+
|
|
156
|
+
assert sys.version_info >= (
|
|
157
|
+
3,
|
|
158
|
+
7,
|
|
159
|
+
), "This code requires Python 3.7 or higher."
|
|
160
|
+
# For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order
|
|
161
|
+
# https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights
|
|
162
|
+
# https://mail.python.org/pipermail/python-dev/2017-December/151283.html
|
|
163
|
+
new_tokens = list(
|
|
164
|
+
islice(self.cache, self._loaded_tokens_count, total_tokens)
|
|
165
|
+
)
|
|
166
|
+
|
|
167
|
+
with open(self.token_path, "a") as pk:
|
|
168
|
+
print(f"saving new {len(new_tokens)} tokens to {self.token_path}...")
|
|
169
|
+
pk.writelines([f"{c}\n" for c in new_tokens])
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class ChemDataReader(TokenIndexerReader):
|
|
173
|
+
"""
|
|
174
|
+
Data reader for chemical data using SMILES tokens.
|
|
175
|
+
"""
|
|
176
|
+
|
|
177
|
+
COLLATOR = RaggedCollator
|
|
178
|
+
|
|
179
|
+
@classmethod
|
|
180
|
+
def name(cls) -> str:
|
|
181
|
+
"""Returns the name of the data reader."""
|
|
182
|
+
return "smiles_token"
|
|
147
183
|
|
|
148
184
|
def _read_data(self, raw_data: str) -> List[int]:
|
|
149
185
|
"""
|
|
@@ -157,15 +193,6 @@ class ChemDataReader(DataReader):
|
|
|
157
193
|
"""
|
|
158
194
|
return [self._get_token_index(v[1]) for v in _tokenize(raw_data)]
|
|
159
195
|
|
|
160
|
-
def on_finish(self) -> None:
|
|
161
|
-
"""
|
|
162
|
-
Saves the current cache of tokens to the token file. This method is called after all data processing is complete.
|
|
163
|
-
"""
|
|
164
|
-
with open(self.token_path, "w") as pk:
|
|
165
|
-
print(f"saving {len(self.cache)} tokens to {self.token_path}...")
|
|
166
|
-
print(f"first 10 tokens: {self.cache[:10]}")
|
|
167
|
-
pk.writelines([f"{c}\n" for c in self.cache])
|
|
168
|
-
|
|
169
196
|
|
|
170
197
|
class DeepChemDataReader(ChemDataReader):
|
|
171
198
|
"""
|
|
@@ -1,20 +1,21 @@
|
|
|
1
1
|
import gc
|
|
2
|
-
import sys
|
|
3
2
|
import traceback
|
|
4
3
|
from datetime import datetime
|
|
5
4
|
from typing import List, LiteralString
|
|
6
5
|
|
|
6
|
+
import pandas as pd
|
|
7
7
|
from torchmetrics.functional.classification import (
|
|
8
8
|
multilabel_auroc,
|
|
9
9
|
multilabel_average_precision,
|
|
10
10
|
multilabel_f1_score,
|
|
11
11
|
)
|
|
12
|
-
from utils import *
|
|
13
12
|
|
|
14
13
|
from chebai.loss.semantic import DisjointLoss
|
|
14
|
+
from chebai.models import Electra
|
|
15
15
|
from chebai.preprocessing.datasets.base import _DynamicDataset
|
|
16
16
|
from chebai.preprocessing.datasets.chebi import ChEBIOver100
|
|
17
17
|
from chebai.preprocessing.datasets.pubchem import PubChemKMeans
|
|
18
|
+
from chebai.result.utils import *
|
|
18
19
|
|
|
19
20
|
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
20
21
|
|
|
@@ -122,7 +123,7 @@ def load_preds_labels(
|
|
|
122
123
|
def get_label_names(data_module):
|
|
123
124
|
if os.path.exists(os.path.join(data_module.processed_dir_main, "classes.txt")):
|
|
124
125
|
with open(os.path.join(data_module.processed_dir_main, "classes.txt")) as fin:
|
|
125
|
-
return [
|
|
126
|
+
return [line.strip() for line in fin]
|
|
126
127
|
print(
|
|
127
128
|
f"Failed to retrieve label names, {os.path.join(data_module.processed_dir_main, 'classes.txt')} not found"
|
|
128
129
|
)
|
|
@@ -131,70 +132,97 @@ def get_label_names(data_module):
|
|
|
131
132
|
|
|
132
133
|
def get_chebi_graph(data_module, label_names):
|
|
133
134
|
if os.path.exists(os.path.join(data_module.raw_dir, "chebi.obo")):
|
|
134
|
-
chebi_graph = data_module.
|
|
135
|
+
chebi_graph = data_module._extract_class_hierarchy(
|
|
135
136
|
os.path.join(data_module.raw_dir, "chebi.obo")
|
|
136
137
|
)
|
|
137
|
-
return chebi_graph.subgraph(label_names)
|
|
138
|
+
return chebi_graph.subgraph([int(n) for n in label_names])
|
|
138
139
|
print(
|
|
139
140
|
f"Failed to retrieve ChEBI graph, {os.path.join(data_module.raw_dir, 'chebi.obo')} not found"
|
|
140
141
|
)
|
|
141
142
|
return None
|
|
142
143
|
|
|
143
144
|
|
|
144
|
-
def get_disjoint_groups():
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
145
|
+
def get_disjoint_groups(disjoint_files):
|
|
146
|
+
if disjoint_files is None:
|
|
147
|
+
disjoint_files = os.path.join("data", "chebi-disjoints.owl")
|
|
148
|
+
disjoint_pairs, disjoint_groups = [], []
|
|
149
|
+
for file in disjoint_files:
|
|
150
|
+
if file.split(".")[-1] == "csv":
|
|
151
|
+
disjoint_pairs += pd.read_csv(file, header=None).values.tolist()
|
|
152
|
+
elif file.split(".")[-1] == "owl":
|
|
153
|
+
with open(file, "r") as f:
|
|
154
|
+
plaintext = f.read()
|
|
155
|
+
segments = plaintext.split("<")
|
|
156
|
+
disjoint_pairs = []
|
|
157
|
+
left = None
|
|
158
|
+
for seg in segments:
|
|
159
|
+
if seg.startswith("rdf:Description ") or seg.startswith(
|
|
160
|
+
"owl:Class"
|
|
161
|
+
):
|
|
162
|
+
left = int(seg.split('rdf:about="&obo;CHEBI_')[1].split('"')[0])
|
|
163
|
+
elif seg.startswith("owl:disjointWith"):
|
|
164
|
+
right = int(
|
|
165
|
+
seg.split('rdf:resource="&obo;CHEBI_')[1].split('"')[0]
|
|
166
|
+
)
|
|
167
|
+
disjoint_pairs.append([left, right])
|
|
168
|
+
|
|
169
|
+
disjoint_groups = []
|
|
170
|
+
for seg in plaintext.split("<rdf:Description>"):
|
|
171
|
+
if "owl;AllDisjointClasses" in seg:
|
|
172
|
+
classes = seg.split('rdf:about="&obo;CHEBI_')[1:]
|
|
173
|
+
classes = [int(c.split('"')[0]) for c in classes]
|
|
174
|
+
disjoint_groups.append(classes)
|
|
175
|
+
else:
|
|
176
|
+
raise NotImplementedError(
|
|
177
|
+
"Unsupported disjoint file format: " + file.split(".")[-1]
|
|
178
|
+
)
|
|
179
|
+
|
|
164
180
|
disjoint_all = disjoint_pairs + disjoint_groups
|
|
165
181
|
# one disjointness is commented out in the owl-file
|
|
166
182
|
# (the correct way would be to parse the owl file and notice the comment symbols, but for this case, it should work)
|
|
167
|
-
|
|
168
|
-
|
|
183
|
+
if [22729, 51880] in disjoint_all:
|
|
184
|
+
disjoint_all.remove([22729, 51880])
|
|
185
|
+
# print(f"Found {len(disjoint_all)} disjoint groups")
|
|
169
186
|
return disjoint_all
|
|
170
187
|
|
|
171
188
|
|
|
172
189
|
class PredictionSmoother:
|
|
173
190
|
"""Removes implication and disjointness violations from predictions"""
|
|
174
191
|
|
|
175
|
-
def __init__(self, dataset):
|
|
176
|
-
|
|
192
|
+
def __init__(self, dataset, label_names=None, disjoint_files=None):
|
|
193
|
+
if label_names:
|
|
194
|
+
self.label_names = label_names
|
|
195
|
+
else:
|
|
196
|
+
self.label_names = get_label_names(dataset)
|
|
177
197
|
self.chebi_graph = get_chebi_graph(dataset, self.label_names)
|
|
178
|
-
self.disjoint_groups = get_disjoint_groups()
|
|
198
|
+
self.disjoint_groups = get_disjoint_groups(disjoint_files)
|
|
179
199
|
|
|
180
200
|
def __call__(self, preds):
|
|
181
|
-
|
|
182
201
|
preds_sum_orig = torch.sum(preds)
|
|
183
|
-
print(f"Preds sum: {preds_sum_orig}")
|
|
184
|
-
# eliminate implication violations by setting each prediction to maximum of its successors
|
|
185
202
|
for i, label in enumerate(self.label_names):
|
|
186
203
|
succs = [
|
|
187
|
-
self.label_names.index(p)
|
|
204
|
+
self.label_names.index(str(p))
|
|
205
|
+
for p in self.chebi_graph.successors(int(label))
|
|
188
206
|
] + [i]
|
|
189
207
|
if len(succs) > 0:
|
|
208
|
+
if torch.max(preds[:, succs], dim=1).values > 0.5 and preds[:, i] < 0.5:
|
|
209
|
+
print(
|
|
210
|
+
f"Correcting prediction for {label} to max of subclasses {list(self.chebi_graph.successors(int(label)))}"
|
|
211
|
+
)
|
|
212
|
+
print(
|
|
213
|
+
f"Original pred: {preds[:, i]}, successors: {preds[:, succs]}"
|
|
214
|
+
)
|
|
190
215
|
preds[:, i] = torch.max(preds[:, succs], dim=1).values
|
|
191
|
-
|
|
216
|
+
if torch.sum(preds) != preds_sum_orig:
|
|
217
|
+
print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")
|
|
192
218
|
preds_sum_orig = torch.sum(preds)
|
|
193
219
|
# step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower)
|
|
194
220
|
preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49)
|
|
195
221
|
for disj_group in self.disjoint_groups:
|
|
196
222
|
disj_group = [
|
|
197
|
-
self.label_names.index(g)
|
|
223
|
+
self.label_names.index(str(g))
|
|
224
|
+
for g in disj_group
|
|
225
|
+
if g in self.label_names
|
|
198
226
|
]
|
|
199
227
|
if len(disj_group) > 1:
|
|
200
228
|
old_preds = preds[:, disj_group]
|
|
@@ -211,14 +239,12 @@ class PredictionSmoother:
|
|
|
211
239
|
print(
|
|
212
240
|
f"disjointness group {[self.label_names[d] for d in disj_group]} changed {samples_changed} samples"
|
|
213
241
|
)
|
|
214
|
-
print(
|
|
215
|
-
f"Preds change after disjointness (step 2): {torch.sum(preds) - preds_sum_orig}"
|
|
216
|
-
)
|
|
217
242
|
preds_sum_orig = torch.sum(preds)
|
|
218
243
|
# step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors
|
|
219
244
|
for i, label in enumerate(self.label_names):
|
|
220
245
|
predecessors = [i] + [
|
|
221
|
-
self.label_names.index(p)
|
|
246
|
+
self.label_names.index(str(p))
|
|
247
|
+
for p in self.chebi_graph.predecessors(int(label))
|
|
222
248
|
]
|
|
223
249
|
lowest_predecessors = torch.min(preds[:, predecessors], dim=1)
|
|
224
250
|
preds[:, i] = lowest_predecessors.values
|