chebai 0.0.2.dev0__tar.gz → 1.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. {chebai-0.0.2.dev0 → chebai-1.0.1}/PKG-INFO +2 -2
  2. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/models/base.py +5 -4
  3. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/datasets/base.py +49 -34
  4. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/datasets/chebi.py +2 -2
  5. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/reader.py +54 -27
  6. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/result/analyse_sem.py +66 -40
  7. chebai-1.0.1/chebai/result/generate_class_properties.py +200 -0
  8. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/result/utils.py +83 -5
  9. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai.egg-info/PKG-INFO +2 -2
  10. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai.egg-info/SOURCES.txt +1 -1
  11. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai.egg-info/requires.txt +1 -1
  12. {chebai-0.0.2.dev0 → chebai-1.0.1}/setup.py +2 -2
  13. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/readers/testChemDataReader.py +62 -9
  14. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/readers/testDeepChemDataReader.py +9 -7
  15. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/readers/testSelfiesReader.py +8 -6
  16. chebai-0.0.2.dev0/chebai/preprocessing/bin/graph_properties/tokens.txt +0 -0
  17. {chebai-0.0.2.dev0 → chebai-1.0.1}/LICENSE +0 -0
  18. {chebai-0.0.2.dev0 → chebai-1.0.1}/README.md +0 -0
  19. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/__init__.py +0 -0
  20. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/__main__.py +0 -0
  21. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/callbacks/__init__.py +0 -0
  22. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/callbacks/epoch_metrics.py +0 -0
  23. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/callbacks/model_checkpoint.py +0 -0
  24. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/callbacks/prediction_callback.py +0 -0
  25. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/callbacks.py +0 -0
  26. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/cli.py +0 -0
  27. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/loggers/__init__.py +0 -0
  28. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/loggers/custom.py +0 -0
  29. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/loss/__init__.py +0 -0
  30. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/loss/bce_weighted.py +0 -0
  31. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/loss/mixed.py +0 -0
  32. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/loss/pretraining.py +0 -0
  33. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/loss/semantic.py +0 -0
  34. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/models/__init__.py +0 -0
  35. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/models/chemberta.py +0 -0
  36. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/models/chemyk.py +0 -0
  37. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/models/electra.py +0 -0
  38. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/models/external/__init__.py +0 -0
  39. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/models/ffn.py +0 -0
  40. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/models/lnn_model.py +0 -0
  41. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/models/lstm.py +0 -0
  42. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/models/recursive.py +0 -0
  43. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/models/strontex.py +0 -0
  44. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/molecule.py +0 -0
  45. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/__init__.py +0 -0
  46. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/bin/BPE_SWJ/merges.txt +0 -0
  47. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/bin/BPE_SWJ/vocab.json +0 -0
  48. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/bin/deepsmiles_token/tokens.txt +0 -0
  49. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/bin/graph/tokens.txt +0 -0
  50. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/bin/selfies/tokens.txt +0 -0
  51. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/bin/smiles_token/tokens.txt +0 -0
  52. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/bin/smiles_token_unlabeled/tokens.txt +0 -0
  53. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/collate.py +0 -0
  54. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/collect_all.py +0 -0
  55. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/datasets/__init__.py +0 -0
  56. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/datasets/pubchem.py +0 -0
  57. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/datasets/tox21.py +0 -0
  58. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/migration/__init__.py +0 -0
  59. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/migration/chebi_data_migration.py +0 -0
  60. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/preprocessing/structures.py +0 -0
  61. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/result/__init__.py +0 -0
  62. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/result/base.py +0 -0
  63. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/result/classification.py +0 -0
  64. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/result/evaluate_predictions.py +0 -0
  65. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/result/molplot.py +0 -0
  66. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/result/prediction_json.py +0 -0
  67. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/result/pretraining.py +0 -0
  68. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/train.py +0 -0
  69. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/trainer/CustomTrainer.py +0 -0
  70. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai/trainer/__init__.py +0 -0
  71. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai.egg-info/dependency_links.txt +0 -0
  72. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai.egg-info/not-zip-safe +0 -0
  73. {chebai-0.0.2.dev0 → chebai-1.0.1}/chebai.egg-info/top_level.txt +0 -0
  74. {chebai-0.0.2.dev0 → chebai-1.0.1}/setup.cfg +0 -0
  75. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/__init__.py +0 -0
  76. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/integration/__init__.py +0 -0
  77. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/integration/testChebiData.py +0 -0
  78. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/integration/testChebiDynamicDataSplits.py +0 -0
  79. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/integration/testCustomBalancedAccuracyMetric.py +0 -0
  80. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/integration/testCustomMacroF1Metric.py +0 -0
  81. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/integration/testPubChemData.py +0 -0
  82. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/integration/testTox21MolNetData.py +0 -0
  83. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/__init__.py +0 -0
  84. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/collators/__init__.py +0 -0
  85. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/collators/testDefaultCollator.py +0 -0
  86. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/collators/testRaggedCollator.py +0 -0
  87. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/dataset_classes/__init__.py +0 -0
  88. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/dataset_classes/testChEBIOverX.py +0 -0
  89. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/dataset_classes/testChebiDataExtractor.py +0 -0
  90. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/dataset_classes/testChebiOverXPartial.py +0 -0
  91. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/dataset_classes/testChebiTermCallback.py +0 -0
  92. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/dataset_classes/testDynamicDataset.py +0 -0
  93. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/dataset_classes/testTox21Challenge.py +0 -0
  94. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/dataset_classes/testXYBaseDataModule.py +0 -0
  95. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/mock_data/__init__.py +0 -0
  96. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/mock_data/ontology_mock_data.py +0 -0
  97. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/mock_data/tox_mock_data.py +0 -0
  98. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/readers/__init__.py +0 -0
  99. {chebai-0.0.2.dev0 → chebai-1.0.1}/tests/unit/readers/testDataReader.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: chebai
3
- Version: 0.0.2.dev0
3
+ Version: 1.0.1
4
4
  Home-page:
5
5
  Author: MGlauer
6
6
  Author-email: martin.glauer@ovgu.de
@@ -10,7 +10,7 @@ Requires-Dist: certifi
10
10
  Requires-Dist: idna
11
11
  Requires-Dist: joblib
12
12
  Requires-Dist: networkx
13
- Requires-Dist: numpy<2
13
+ Requires-Dist: numpy
14
14
  Requires-Dist: pandas
15
15
  Requires-Dist: python-dateutil
16
16
  Requires-Dist: pytz
@@ -1,9 +1,9 @@
1
1
  import logging
2
- from typing import Any, Dict, Optional, Union, Iterable
2
+ from abc import ABC, abstractmethod
3
+ from typing import Any, Dict, Iterable, Optional, Union
3
4
 
4
5
  import torch
5
6
  from lightning.pytorch.core.module import LightningModule
6
- from torchmetrics import Metric
7
7
 
8
8
  from chebai.preprocessing.structures import XYData
9
9
 
@@ -12,7 +12,7 @@ logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
12
12
  _MODEL_REGISTRY = dict()
13
13
 
14
14
 
15
- class ChebaiBaseNet(LightningModule):
15
+ class ChebaiBaseNet(LightningModule, ABC):
16
16
  """
17
17
  Base class for Chebai neural network models inheriting from PyTorch Lightning's LightningModule.
18
18
 
@@ -353,6 +353,7 @@ class ChebaiBaseNet(LightningModule):
353
353
  logger=True,
354
354
  )
355
355
 
356
+ @abstractmethod
356
357
  def forward(self, x: Dict[str, Any]) -> torch.Tensor:
357
358
  """
358
359
  Defines the forward pass.
@@ -363,7 +364,7 @@ class ChebaiBaseNet(LightningModule):
363
364
  Returns:
364
365
  torch.Tensor: The model output.
365
366
  """
366
- raise NotImplementedError
367
+ pass
367
368
 
368
369
  def configure_optimizers(self, **kwargs) -> torch.optim.Optimizer:
369
370
  """
@@ -29,7 +29,8 @@ class XYBaseDataModule(LightningDataModule):
29
29
 
30
30
  Args:
31
31
  batch_size (int): The batch size for data loading. Default is 1.
32
- train_split (float): The ratio of training data to total data and of test data to (validation + test) data. Default is 0.85.
32
+ test_split (float): The ratio of test data to total data. Default is 0.1.
33
+ validation_split (float): The ratio of validation data to total data. Default is 0.05.
33
34
  reader_kwargs (dict): Additional keyword arguments to be passed to the data reader. Default is None.
34
35
  prediction_kind (str): The kind of prediction to be performed (only relevant for the predict_dataloader). Default is "test".
35
36
  data_limit (Optional[int]): The maximum number of data samples to load. If set to None, the complete dataset will be used. Default is None.
@@ -45,7 +46,8 @@ class XYBaseDataModule(LightningDataModule):
45
46
  Attributes:
46
47
  READER (DataReader): The data reader class to use.
47
48
  reader (DataReader): An instance of the data reader class.
48
- train_split (float): The ratio of training data to total data.
49
+ test_split (float): The ratio of test data to total data.
50
+ validation_split (float): The ratio of validation data to total data.
49
51
  batch_size (int): The batch size for data loading.
50
52
  prediction_kind (str): The kind of prediction to be performed.
51
53
  data_limit (Optional[int]): The maximum number of data samples to load.
@@ -68,7 +70,8 @@ class XYBaseDataModule(LightningDataModule):
68
70
  def __init__(
69
71
  self,
70
72
  batch_size: int = 1,
71
- train_split: float = 0.85,
73
+ test_split: Optional[float] = 0.1,
74
+ validation_split: Optional[float] = 0.05,
72
75
  reader_kwargs: Optional[dict] = None,
73
76
  prediction_kind: str = "test",
74
77
  data_limit: Optional[int] = None,
@@ -86,7 +89,9 @@ class XYBaseDataModule(LightningDataModule):
86
89
  if reader_kwargs is None:
87
90
  reader_kwargs = dict()
88
91
  self.reader = self.READER(**reader_kwargs)
89
- self.train_split = train_split
92
+ self.test_split = test_split
93
+ self.validation_split = validation_split
94
+
90
95
  self.batch_size = batch_size
91
96
  self.prediction_kind = prediction_kind
92
97
  self.data_limit = data_limit
@@ -335,8 +340,9 @@ class XYBaseDataModule(LightningDataModule):
335
340
  val
336
341
  for val in data
337
342
  if val["features"] is not None
338
- and self.n_token_limit is None
339
- or len(val["features"]) <= self.n_token_limit
343
+ and (
344
+ self.n_token_limit is None or len(val["features"]) <= self.n_token_limit
345
+ )
340
346
  ]
341
347
 
342
348
  return data
@@ -439,13 +445,25 @@ class XYBaseDataModule(LightningDataModule):
439
445
  ):
440
446
  self.setup_processed()
441
447
 
442
- if not ("keep_reader" in kwargs and kwargs["keep_reader"]):
443
- self.reader.on_finish()
448
+ self._after_setup(**kwargs)
449
+
450
+ def _after_setup(self, **kwargs):
451
+ """
452
+ Finalize the setup process after ensuring the processed data is available.
444
453
 
454
+ This method performs post-setup tasks like finalizing the reader and setting internal properties.
455
+ """
456
+ self.reader.on_finish()
445
457
  self._set_processed_data_props()
446
458
 
447
459
  def _set_processed_data_props(self):
460
+ """
461
+ Load processed data and extract metadata.
448
462
 
463
+ Sets:
464
+ - self._num_of_labels: Number of target labels in the dataset.
465
+ - self._feature_vector_size: Maximum feature vector length across all data points.
466
+ """
449
467
  data_pt = torch.load(
450
468
  os.path.join(self.processed_dir, self.processed_file_names_dict["data"]),
451
469
  weights_only=False,
@@ -1009,15 +1027,13 @@ class _DynamicDataset(XYBaseDataModule, ABC):
1009
1027
 
1010
1028
  labels_list = df["labels"].tolist()
1011
1029
 
1012
- test_size = 1 - self.train_split - (1 - self.train_split) ** 2
1013
-
1014
1030
  if len(labels_list[0]) > 1:
1015
1031
  splitter = MultilabelStratifiedShuffleSplit(
1016
- n_splits=1, test_size=test_size, random_state=seed
1032
+ n_splits=1, test_size=self.test_split, random_state=seed
1017
1033
  )
1018
1034
  else:
1019
1035
  splitter = StratifiedShuffleSplit(
1020
- n_splits=1, test_size=test_size, random_state=seed
1036
+ n_splits=1, test_size=self.test_split, random_state=seed
1021
1037
  )
1022
1038
 
1023
1039
  train_indices, test_indices = next(splitter.split(labels_list, labels_list))
@@ -1070,16 +1086,17 @@ class _DynamicDataset(XYBaseDataModule, ABC):
1070
1086
 
1071
1087
  return folds
1072
1088
 
1073
- # scale val set size by 1/self.train_split to compensate for (hypothetical) test set size (1-self.train_split)
1074
- test_size = ((1 - self.train_split) ** 2) / self.train_split
1075
-
1076
1089
  if len(labels_list_trainval[0]) > 1:
1077
1090
  splitter = MultilabelStratifiedShuffleSplit(
1078
- n_splits=1, test_size=test_size, random_state=seed
1091
+ n_splits=1,
1092
+ test_size=self.validation_split / (1 - self.test_split),
1093
+ random_state=seed,
1079
1094
  )
1080
1095
  else:
1081
1096
  splitter = StratifiedShuffleSplit(
1082
- n_splits=1, test_size=test_size, random_state=seed
1097
+ n_splits=1,
1098
+ test_size=self.validation_split / (1 - self.test_split),
1099
+ random_state=seed,
1083
1100
  )
1084
1101
 
1085
1102
  train_indices, validation_indices = next(
@@ -1102,7 +1119,9 @@ class _DynamicDataset(XYBaseDataModule, ABC):
1102
1119
  splits_df = pd.read_csv(self.splits_file_path)
1103
1120
 
1104
1121
  filename = self.processed_file_names_dict["data"]
1105
- data = self.load_processed_data(filename=filename)
1122
+ data = self.load_processed_data_from_file(
1123
+ os.path.join(self.processed_dir, filename)
1124
+ )
1106
1125
  df_data = pd.DataFrame(data)
1107
1126
 
1108
1127
  train_ids = splits_df[splits_df["split"] == "train"]["id"]
@@ -1113,6 +1132,7 @@ class _DynamicDataset(XYBaseDataModule, ABC):
1113
1132
  self._dynamic_df_val = df_data[df_data["ident"].isin(validation_ids)]
1114
1133
  self._dynamic_df_test = df_data[df_data["ident"].isin(test_ids)]
1115
1134
 
1135
+ # ------------------------------ Phase: DataLoaders -----------------------------------
1116
1136
  def load_processed_data(
1117
1137
  self, kind: Optional[str] = None, filename: Optional[str] = None
1118
1138
  ) -> List[Dict[str, Any]]:
@@ -1148,24 +1168,19 @@ class _DynamicDataset(XYBaseDataModule, ABC):
1148
1168
 
1149
1169
  # If both kind and filename are given, use filename
1150
1170
  if kind is not None and filename is None:
1151
- try:
1152
- if self.use_inner_cross_validation and kind != "test":
1153
- filename = self.processed_file_names_dict[
1154
- f"fold_{self.fold_index}_{kind}"
1155
- ]
1156
- else:
1157
- data_df = self.dynamic_split_dfs[kind]
1158
- return data_df.to_dict(orient="records")
1159
- except KeyError:
1160
- kind = f"{kind}"
1171
+ if self.use_inner_cross_validation and kind != "test":
1172
+ filename = self.processed_file_names_dict[
1173
+ f"fold_{self.fold_index}_{kind}"
1174
+ ]
1175
+ else:
1176
+ data_df = self.dynamic_split_dfs[kind]
1177
+ return data_df.to_dict(orient="records")
1161
1178
 
1162
1179
  # If filename is provided
1163
- try:
1164
- return torch.load(
1165
- os.path.join(self.processed_dir, filename), weights_only=False
1166
- )
1167
- except FileNotFoundError:
1168
- raise FileNotFoundError(f"File {filename} doesn't exist")
1180
+ return self.load_processed_data_from_file(filename)
1181
+
1182
+ def load_processed_data_from_file(self, filename):
1183
+ return torch.load(os.path.join(filename), weights_only=False)
1169
1184
 
1170
1185
  # ------------------------------ Phase: Raw Properties -----------------------------------
1171
1186
  @property
@@ -401,8 +401,8 @@ class _ChEBIDataExtractor(_DynamicDataset, ABC):
401
401
  """
402
402
  try:
403
403
  filename = self.processed_file_names_dict["data"]
404
- data_chebi_version = torch.load(
405
- os.path.join(self.processed_dir, filename), weights_only=False
404
+ data_chebi_version = self.load_processed_data_from_file(
405
+ os.path.join(self.processed_dir, filename)
406
406
  )
407
407
  except FileNotFoundError:
408
408
  raise FileNotFoundError(
@@ -1,5 +1,9 @@
1
+ import inspect
1
2
  import os
2
- from typing import Any, Dict, List, Optional, Tuple
3
+ import sys
4
+ from abc import ABC
5
+ from itertools import islice
6
+ from typing import Any, Dict, List, Optional
3
7
 
4
8
  import deepsmiles
5
9
  import selfies as sf
@@ -36,7 +40,7 @@ class DataReader:
36
40
  if collator_kwargs is None:
37
41
  collator_kwargs = dict()
38
42
  self.collator = self.COLLATOR(**collator_kwargs)
39
- self.dirname = os.path.dirname(__file__)
43
+ self.dirname = os.path.dirname(inspect.getfile(self.__class__))
40
44
  self._token_path = token_path
41
45
 
42
46
  def _get_raw_data(self, row: Dict[str, Any]) -> Any:
@@ -117,33 +121,65 @@ class DataReader:
117
121
  return
118
122
 
119
123
 
120
- class ChemDataReader(DataReader):
124
+ class TokenIndexerReader(DataReader, ABC):
121
125
  """
122
- Data reader for chemical data using SMILES tokens.
126
+ Abstract base class for reading tokenized data and mapping tokens to unique indices.
123
127
 
124
- Args:
125
- collator_kwargs: Optional dictionary of keyword arguments for the collator.
126
- token_path: Optional path for the token file.
127
- kwargs: Additional keyword arguments.
128
+ This class maintains a cache of token-to-index mappings that can be extended during runtime,
129
+ and saves new tokens to a persistent file at the end of processing.
128
130
  """
129
131
 
130
- COLLATOR = RaggedCollator
131
-
132
- @classmethod
133
- def name(cls) -> str:
134
- """Returns the name of the data reader."""
135
- return "smiles_token"
136
-
137
132
  def __init__(self, *args, **kwargs):
138
133
  super().__init__(*args, **kwargs)
139
134
  with open(self.token_path, "r") as pk:
140
- self.cache = [x.strip() for x in pk]
135
+ self.cache: Dict[str, int] = {
136
+ token.strip(): idx for idx, token in enumerate(pk)
137
+ }
138
+ self._loaded_tokens_count = len(self.cache)
141
139
 
142
140
  def _get_token_index(self, token: str) -> int:
143
141
  """Returns a unique number for each token, automatically adds new tokens."""
144
142
  if not str(token) in self.cache:
145
- self.cache.append(str(token))
146
- return self.cache.index(str(token)) + EMBEDDING_OFFSET
143
+ self.cache[(str(token))] = len(self.cache)
144
+ return self.cache[str(token)] + EMBEDDING_OFFSET
145
+
146
+ def on_finish(self) -> None:
147
+ """
148
+ Saves the current cache of tokens to the token file.This method is called after all data processing is complete.
149
+ """
150
+ print(f"first 10 tokens: {list(islice(self.cache, 10))}")
151
+
152
+ total_tokens = len(self.cache)
153
+ if total_tokens > self._loaded_tokens_count:
154
+ print("New tokens added to the cache, Saving them to token file.....")
155
+
156
+ assert sys.version_info >= (
157
+ 3,
158
+ 7,
159
+ ), "This code requires Python 3.7 or higher."
160
+ # For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order
161
+ # https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights
162
+ # https://mail.python.org/pipermail/python-dev/2017-December/151283.html
163
+ new_tokens = list(
164
+ islice(self.cache, self._loaded_tokens_count, total_tokens)
165
+ )
166
+
167
+ with open(self.token_path, "a") as pk:
168
+ print(f"saving new {len(new_tokens)} tokens to {self.token_path}...")
169
+ pk.writelines([f"{c}\n" for c in new_tokens])
170
+
171
+
172
+ class ChemDataReader(TokenIndexerReader):
173
+ """
174
+ Data reader for chemical data using SMILES tokens.
175
+ """
176
+
177
+ COLLATOR = RaggedCollator
178
+
179
+ @classmethod
180
+ def name(cls) -> str:
181
+ """Returns the name of the data reader."""
182
+ return "smiles_token"
147
183
 
148
184
  def _read_data(self, raw_data: str) -> List[int]:
149
185
  """
@@ -157,15 +193,6 @@ class ChemDataReader(DataReader):
157
193
  """
158
194
  return [self._get_token_index(v[1]) for v in _tokenize(raw_data)]
159
195
 
160
- def on_finish(self) -> None:
161
- """
162
- Saves the current cache of tokens to the token file. This method is called after all data processing is complete.
163
- """
164
- with open(self.token_path, "w") as pk:
165
- print(f"saving {len(self.cache)} tokens to {self.token_path}...")
166
- print(f"first 10 tokens: {self.cache[:10]}")
167
- pk.writelines([f"{c}\n" for c in self.cache])
168
-
169
196
 
170
197
  class DeepChemDataReader(ChemDataReader):
171
198
  """
@@ -1,20 +1,21 @@
1
1
  import gc
2
- import sys
3
2
  import traceback
4
3
  from datetime import datetime
5
4
  from typing import List, LiteralString
6
5
 
6
+ import pandas as pd
7
7
  from torchmetrics.functional.classification import (
8
8
  multilabel_auroc,
9
9
  multilabel_average_precision,
10
10
  multilabel_f1_score,
11
11
  )
12
- from utils import *
13
12
 
14
13
  from chebai.loss.semantic import DisjointLoss
14
+ from chebai.models import Electra
15
15
  from chebai.preprocessing.datasets.base import _DynamicDataset
16
16
  from chebai.preprocessing.datasets.chebi import ChEBIOver100
17
17
  from chebai.preprocessing.datasets.pubchem import PubChemKMeans
18
+ from chebai.result.utils import *
18
19
 
19
20
  DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
20
21
 
@@ -122,7 +123,7 @@ def load_preds_labels(
122
123
  def get_label_names(data_module):
123
124
  if os.path.exists(os.path.join(data_module.processed_dir_main, "classes.txt")):
124
125
  with open(os.path.join(data_module.processed_dir_main, "classes.txt")) as fin:
125
- return [int(line.strip()) for line in fin]
126
+ return [line.strip() for line in fin]
126
127
  print(
127
128
  f"Failed to retrieve label names, {os.path.join(data_module.processed_dir_main, 'classes.txt')} not found"
128
129
  )
@@ -131,70 +132,97 @@ def get_label_names(data_module):
131
132
 
132
133
  def get_chebi_graph(data_module, label_names):
133
134
  if os.path.exists(os.path.join(data_module.raw_dir, "chebi.obo")):
134
- chebi_graph = data_module.extract_class_hierarchy(
135
+ chebi_graph = data_module._extract_class_hierarchy(
135
136
  os.path.join(data_module.raw_dir, "chebi.obo")
136
137
  )
137
- return chebi_graph.subgraph(label_names)
138
+ return chebi_graph.subgraph([int(n) for n in label_names])
138
139
  print(
139
140
  f"Failed to retrieve ChEBI graph, {os.path.join(data_module.raw_dir, 'chebi.obo')} not found"
140
141
  )
141
142
  return None
142
143
 
143
144
 
144
- def get_disjoint_groups():
145
- disjoints_owl_file = os.path.join("data", "chebi-disjoints.owl")
146
- with open(disjoints_owl_file, "r") as f:
147
- plaintext = f.read()
148
- segments = plaintext.split("<")
149
- disjoint_pairs = []
150
- left = None
151
- for seg in segments:
152
- if seg.startswith("rdf:Description ") or seg.startswith("owl:Class"):
153
- left = int(seg.split('rdf:about="&obo;CHEBI_')[1].split('"')[0])
154
- elif seg.startswith("owl:disjointWith"):
155
- right = int(seg.split('rdf:resource="&obo;CHEBI_')[1].split('"')[0])
156
- disjoint_pairs.append([left, right])
157
-
158
- disjoint_groups = []
159
- for seg in plaintext.split("<rdf:Description>"):
160
- if "owl;AllDisjointClasses" in seg:
161
- classes = seg.split('rdf:about="&obo;CHEBI_')[1:]
162
- classes = [int(c.split('"')[0]) for c in classes]
163
- disjoint_groups.append(classes)
145
+ def get_disjoint_groups(disjoint_files):
146
+ if disjoint_files is None:
147
+ disjoint_files = os.path.join("data", "chebi-disjoints.owl")
148
+ disjoint_pairs, disjoint_groups = [], []
149
+ for file in disjoint_files:
150
+ if file.split(".")[-1] == "csv":
151
+ disjoint_pairs += pd.read_csv(file, header=None).values.tolist()
152
+ elif file.split(".")[-1] == "owl":
153
+ with open(file, "r") as f:
154
+ plaintext = f.read()
155
+ segments = plaintext.split("<")
156
+ disjoint_pairs = []
157
+ left = None
158
+ for seg in segments:
159
+ if seg.startswith("rdf:Description ") or seg.startswith(
160
+ "owl:Class"
161
+ ):
162
+ left = int(seg.split('rdf:about="&obo;CHEBI_')[1].split('"')[0])
163
+ elif seg.startswith("owl:disjointWith"):
164
+ right = int(
165
+ seg.split('rdf:resource="&obo;CHEBI_')[1].split('"')[0]
166
+ )
167
+ disjoint_pairs.append([left, right])
168
+
169
+ disjoint_groups = []
170
+ for seg in plaintext.split("<rdf:Description>"):
171
+ if "owl;AllDisjointClasses" in seg:
172
+ classes = seg.split('rdf:about="&obo;CHEBI_')[1:]
173
+ classes = [int(c.split('"')[0]) for c in classes]
174
+ disjoint_groups.append(classes)
175
+ else:
176
+ raise NotImplementedError(
177
+ "Unsupported disjoint file format: " + file.split(".")[-1]
178
+ )
179
+
164
180
  disjoint_all = disjoint_pairs + disjoint_groups
165
181
  # one disjointness is commented out in the owl-file
166
182
  # (the correct way would be to parse the owl file and notice the comment symbols, but for this case, it should work)
167
- disjoint_all.remove([22729, 51880])
168
- print(f"Found {len(disjoint_all)} disjoint groups")
183
+ if [22729, 51880] in disjoint_all:
184
+ disjoint_all.remove([22729, 51880])
185
+ # print(f"Found {len(disjoint_all)} disjoint groups")
169
186
  return disjoint_all
170
187
 
171
188
 
172
189
  class PredictionSmoother:
173
190
  """Removes implication and disjointness violations from predictions"""
174
191
 
175
- def __init__(self, dataset):
176
- self.label_names = get_label_names(dataset)
192
+ def __init__(self, dataset, label_names=None, disjoint_files=None):
193
+ if label_names:
194
+ self.label_names = label_names
195
+ else:
196
+ self.label_names = get_label_names(dataset)
177
197
  self.chebi_graph = get_chebi_graph(dataset, self.label_names)
178
- self.disjoint_groups = get_disjoint_groups()
198
+ self.disjoint_groups = get_disjoint_groups(disjoint_files)
179
199
 
180
200
  def __call__(self, preds):
181
-
182
201
  preds_sum_orig = torch.sum(preds)
183
- print(f"Preds sum: {preds_sum_orig}")
184
- # eliminate implication violations by setting each prediction to maximum of its successors
185
202
  for i, label in enumerate(self.label_names):
186
203
  succs = [
187
- self.label_names.index(p) for p in self.chebi_graph.successors(label)
204
+ self.label_names.index(str(p))
205
+ for p in self.chebi_graph.successors(int(label))
188
206
  ] + [i]
189
207
  if len(succs) > 0:
208
+ if torch.max(preds[:, succs], dim=1).values > 0.5 and preds[:, i] < 0.5:
209
+ print(
210
+ f"Correcting prediction for {label} to max of subclasses {list(self.chebi_graph.successors(int(label)))}"
211
+ )
212
+ print(
213
+ f"Original pred: {preds[:, i]}, successors: {preds[:, succs]}"
214
+ )
190
215
  preds[:, i] = torch.max(preds[:, succs], dim=1).values
191
- print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")
216
+ if torch.sum(preds) != preds_sum_orig:
217
+ print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")
192
218
  preds_sum_orig = torch.sum(preds)
193
219
  # step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower)
194
220
  preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49)
195
221
  for disj_group in self.disjoint_groups:
196
222
  disj_group = [
197
- self.label_names.index(g) for g in disj_group if g in self.label_names
223
+ self.label_names.index(str(g))
224
+ for g in disj_group
225
+ if g in self.label_names
198
226
  ]
199
227
  if len(disj_group) > 1:
200
228
  old_preds = preds[:, disj_group]
@@ -211,14 +239,12 @@ class PredictionSmoother:
211
239
  print(
212
240
  f"disjointness group {[self.label_names[d] for d in disj_group]} changed {samples_changed} samples"
213
241
  )
214
- print(
215
- f"Preds change after disjointness (step 2): {torch.sum(preds) - preds_sum_orig}"
216
- )
217
242
  preds_sum_orig = torch.sum(preds)
218
243
  # step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors
219
244
  for i, label in enumerate(self.label_names):
220
245
  predecessors = [i] + [
221
- self.label_names.index(p) for p in self.chebi_graph.predecessors(label)
246
+ self.label_names.index(str(p))
247
+ for p in self.chebi_graph.predecessors(int(label))
222
248
  ]
223
249
  lowest_predecessors = torch.min(preds[:, predecessors], dim=1)
224
250
  preds[:, i] = lowest_predecessors.values