ddi-fw 0.0.217__tar.gz → 0.0.219__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/PKG-INFO +1 -1
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/pyproject.toml +1 -1
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/core.py +1 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/base.py +24 -8
- ddi_fw-0.0.219/src/ddi_fw/datasets/mdf_sa_ddi/base.py +375 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/ml/__init__.py +2 -1
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/ml/ml_helper.py +26 -30
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/ml/model_wrapper.py +0 -1
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/ml/tensorflow_wrapper.py +165 -89
- ddi_fw-0.0.219/src/ddi_fw/ml/tracking_service.py +194 -0
- ddi_fw-0.0.217/src/ddi_fw/pipeline/multi_pipeline_v2.py → ddi_fw-0.0.219/src/ddi_fw/pipeline/multi_pipeline.py +8 -11
- ddi_fw-0.0.219/src/ddi_fw/pipeline/pipeline.py +148 -0
- ddi_fw-0.0.219/src/ddi_fw/utils/utils.py +117 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw.egg-info/PKG-INFO +1 -1
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw.egg-info/SOURCES.txt +2 -1
- ddi_fw-0.0.217/src/ddi_fw/datasets/mdf_sa_ddi/base.py +0 -164
- ddi_fw-0.0.217/src/ddi_fw/pipeline/pipeline.py +0 -206
- ddi_fw-0.0.217/src/ddi_fw/utils/utils.py +0 -117
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/README.md +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/setup.cfg +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/__init__.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/dataset_splitter.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/db_utils.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/data/event.db +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/debug.log +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes/test_indexes.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes/train_fold_0.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes/train_fold_1.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes/train_fold_2.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes/train_fold_3.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes/train_fold_4.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes/train_indexes.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes/validation_fold_0.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes/validation_fold_1.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes/validation_fold_2.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes/validation_fold_3.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes/validation_fold_4.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes_old/test_indexes.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes_old/train_fold_0.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes_old/train_fold_1.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes_old/train_fold_2.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes_old/train_fold_3.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes_old/train_fold_4.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes_old/train_indexes.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes_old/validation_fold_0.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes_old/validation_fold_1.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes_old/validation_fold_2.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes_old/validation_fold_3.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/indexes_old/validation_fold_4.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl/readme.md +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl_text/base.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl_text/data/event.db +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl_text/indexes/test_indexes.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl_text/indexes/train_fold_0.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl_text/indexes/train_fold_1.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl_text/indexes/train_fold_2.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl_text/indexes/train_fold_3.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl_text/indexes/train_fold_4.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl_text/indexes/train_indexes.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl_text/indexes/validation_fold_0.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl_text/indexes/validation_fold_1.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl_text/indexes/validation_fold_2.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl_text/indexes/validation_fold_3.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/ddi_mdl_text/indexes/validation_fold_4.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/mdf_sa_ddi/__init__.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/mdf_sa_ddi/df_extraction_cleanxiaoyu50.csv +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/mdf_sa_ddi/drug_information_del_noDDIxiaoyu50.csv +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/mdf_sa_ddi/indexes/test_indexes.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/mdf_sa_ddi/indexes/train_fold_0.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/mdf_sa_ddi/indexes/train_fold_1.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/mdf_sa_ddi/indexes/train_fold_2.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/mdf_sa_ddi/indexes/train_fold_3.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/mdf_sa_ddi/indexes/train_fold_4.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/mdf_sa_ddi/indexes/train_indexes.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/mdf_sa_ddi/indexes/validation_fold_0.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/mdf_sa_ddi/indexes/validation_fold_1.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/mdf_sa_ddi/indexes/validation_fold_2.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/mdf_sa_ddi/indexes/validation_fold_3.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/mdf_sa_ddi/indexes/validation_fold_4.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/mdf_sa_ddi/mdf-sa-ddi.zip +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/datasets/setup_._py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/drugbank/__init__.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/drugbank/drugbank.xsd +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/drugbank/drugbank_parser.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/drugbank/drugbank_processor.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/drugbank/drugbank_processor_org.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/drugbank/event_extractor.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/langchain/__init__.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/langchain/embeddings.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/langchain/sentence_splitter.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/langchain/storage.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/ml/evaluation_helper.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/ml/pytorch_wrapper.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/ner/__init__.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/ner/mmlrestclient.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/ner/ner.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/pipeline/__init__.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/pipeline/multi_modal_combination_strategy.py +0 -0
- /ddi_fw-0.0.217/src/ddi_fw/pipeline/multi_pipeline.py → /ddi_fw-0.0.219/src/ddi_fw/pipeline/multi_pipeline_org.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/pipeline/ner_pipeline.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/utils/__init__.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/utils/categorical_data_encoding_checker.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/utils/enums.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/utils/json_helper.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/utils/kaggle.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/utils/numpy_utils.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/utils/package_helper.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/utils/py7zr_helper.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/utils/zip_helper.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/vectorization/__init__.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/vectorization/feature_vector_generation.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw/vectorization/idf_helper.py +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw.egg-info/dependency_links.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw.egg-info/requires.txt +0 -0
- {ddi_fw-0.0.217 → ddi_fw-0.0.219}/src/ddi_fw.egg-info/top_level.txt +0 -0
@@ -73,6 +73,7 @@ class BaseDataset(BaseModel, abc.ABC):
|
|
73
73
|
train_idx_arr: Optional[List[np.ndarray]] = None
|
74
74
|
val_idx_arr: Optional[List[np.ndarray]] = None
|
75
75
|
columns: List[str] = []
|
76
|
+
additional_config: Optional[Dict[str, Any]] = None
|
76
77
|
|
77
78
|
class Config:
|
78
79
|
arbitrary_types_allowed = True
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import pathlib
|
2
|
-
from typing import List, Optional, Tuple
|
2
|
+
from typing import Any, List, Optional, Tuple
|
3
3
|
from ddi_fw.datasets.core import BaseDataset, TextDatasetMixin, generate_sim_matrices_new, generate_vectors
|
4
4
|
from ddi_fw.datasets.db_utils import create_connection
|
5
5
|
import numpy as np
|
@@ -9,6 +9,8 @@ from abc import ABC, abstractmethod
|
|
9
9
|
from sklearn.preprocessing import LabelBinarizer
|
10
10
|
import logging
|
11
11
|
|
12
|
+
from ddi_fw.ner.ner import CTakesNER
|
13
|
+
|
12
14
|
|
13
15
|
try:
|
14
16
|
from ddi_fw.vectorization import IDF
|
@@ -49,6 +51,7 @@ class DDIMDLDataset(BaseDataset,TextDatasetMixin):
|
|
49
51
|
tui_threshold: float | None = None
|
50
52
|
cui_threshold: float | None = None
|
51
53
|
entities_threshold: float | None = None
|
54
|
+
_ner_threshold: dict[str,Any] |None = None
|
52
55
|
|
53
56
|
# @model_validator
|
54
57
|
|
@@ -63,6 +66,18 @@ class DDIMDLDataset(BaseDataset,TextDatasetMixin):
|
|
63
66
|
|
64
67
|
super().__init__(**kwargs)
|
65
68
|
|
69
|
+
# self.additional_config = kwargs.get('dataset_additional_config', {})
|
70
|
+
if self.additional_config:
|
71
|
+
ner = self.additional_config.get('ner', {})
|
72
|
+
ner_data_file = ner.get('data_file', None)
|
73
|
+
self._ner_threshold = ner.get('thresholds', None)
|
74
|
+
# if self.ner_threshold:
|
75
|
+
# for k, v in self.ner_threshold.items():
|
76
|
+
# kwargs[k] = v
|
77
|
+
|
78
|
+
self.ner_df = CTakesNER(df=None).load(
|
79
|
+
filename=ner_data_file) if ner_data_file else None
|
80
|
+
|
66
81
|
columns = kwargs['columns']
|
67
82
|
if columns:
|
68
83
|
chemical_property_columns = []
|
@@ -155,13 +170,14 @@ class DDIMDLDataset(BaseDataset,TextDatasetMixin):
|
|
155
170
|
|
156
171
|
# for key in filtered_ner_df.keys():
|
157
172
|
for key in self.ner_columns:
|
158
|
-
threshold = 0
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
173
|
+
threshold = self._ner_threshold.get(key, 0) if self._ner_threshold else 0
|
174
|
+
# threshold = 0
|
175
|
+
# if key.startswith('tui'):
|
176
|
+
# threshold = self.tui_threshold
|
177
|
+
# if key.startswith('cui'):
|
178
|
+
# threshold = self.cui_threshold
|
179
|
+
# if key.startswith('entities'):
|
180
|
+
# threshold = self.entities_threshold
|
165
181
|
combined_df[key] = filtered_ner_df[key]
|
166
182
|
valid_codes = idf_scores_df[idf_scores_df[key]
|
167
183
|
> threshold].index
|
@@ -0,0 +1,375 @@
|
|
1
|
+
import os
|
2
|
+
import pathlib
|
3
|
+
from typing import Any, List, Optional, Tuple
|
4
|
+
from ddi_fw.datasets.core import BaseDataset, TextDatasetMixin, generate_sim_matrices_new, generate_vectors
|
5
|
+
from ddi_fw.datasets.db_utils import create_connection
|
6
|
+
import numpy as np
|
7
|
+
import pandas as pd
|
8
|
+
from pydantic import BaseModel, Field, model_validator, root_validator
|
9
|
+
from abc import ABC, abstractmethod
|
10
|
+
from sklearn.preprocessing import LabelBinarizer
|
11
|
+
import logging
|
12
|
+
|
13
|
+
from ddi_fw.ner.ner import CTakesNER
|
14
|
+
from ddi_fw.utils.zip_helper import ZipHelper
|
15
|
+
|
16
|
+
|
17
|
+
try:
|
18
|
+
from ddi_fw.vectorization import IDF
|
19
|
+
except ImportError:
|
20
|
+
raise ImportError(
|
21
|
+
"Failed to import vectorization module. Ensure that the module exists and is correctly installed. ")
|
22
|
+
|
23
|
+
logger = logging.getLogger(__name__)
|
24
|
+
|
25
|
+
# Constants for embedding, chemical properties, and NER columns
|
26
|
+
LIST_OF_EMBEDDING_COLUMNS = [
|
27
|
+
'all_text', 'description', 'synthesis_reference', 'indication',
|
28
|
+
'pharmacodynamics', 'mechanism_of_action', 'toxicity', 'metabolism',
|
29
|
+
'absorption', 'half_life', 'protein_binding', 'route_of_elimination',
|
30
|
+
'volume_of_distribution', 'clearance'
|
31
|
+
]
|
32
|
+
|
33
|
+
LIST_OF_CHEMICAL_PROPERTY_COLUMNS = ['enzyme', 'target', 'smile']
|
34
|
+
LIST_OF_NER_COLUMNS = ['tui', 'cui', 'entities']
|
35
|
+
|
36
|
+
HERE = pathlib.Path(__file__).resolve().parent
|
37
|
+
|
38
|
+
class MDFSADDIDataset(BaseDataset,TextDatasetMixin):
|
39
|
+
# def __init__(self, embedding_size,
|
40
|
+
# embedding_dict,
|
41
|
+
# embeddings_pooling_strategy: PoolingStrategy,
|
42
|
+
# ner_df,
|
43
|
+
# chemical_property_columns=['enzyme',
|
44
|
+
# 'target',
|
45
|
+
# 'smile'],
|
46
|
+
# embedding_columns=[],
|
47
|
+
# ner_columns=[],
|
48
|
+
# **kwargs):
|
49
|
+
|
50
|
+
# columns = kwargs['columns']
|
51
|
+
# if columns:
|
52
|
+
# chemical_property_columns = []
|
53
|
+
# embedding_columns=[]
|
54
|
+
# ner_columns=[]
|
55
|
+
# for column in columns:
|
56
|
+
# if column in list_of_chemical_property_columns:
|
57
|
+
# chemical_property_columns.append(column)
|
58
|
+
# elif column in list_of_embedding_columns:
|
59
|
+
# embedding_columns.append(column)
|
60
|
+
# elif column in list_of_ner_columns:
|
61
|
+
# ner_columns.append(column)
|
62
|
+
# # elif column == 'smile_2':
|
63
|
+
# # continue
|
64
|
+
# else:
|
65
|
+
# raise Exception(f"{column} is not related this dataset")
|
66
|
+
|
67
|
+
|
68
|
+
# super().__init__(embedding_size=embedding_size,
|
69
|
+
# embedding_dict=embedding_dict,
|
70
|
+
# embeddings_pooling_strategy=embeddings_pooling_strategy,
|
71
|
+
# ner_df=ner_df,
|
72
|
+
# chemical_property_columns=chemical_property_columns,
|
73
|
+
# embedding_columns=embedding_columns,
|
74
|
+
# ner_columns=ner_columns,
|
75
|
+
# **kwargs)
|
76
|
+
|
77
|
+
# db_zip_path = HERE.joinpath('mdf-sa-ddi.zip')
|
78
|
+
# db_path = HERE.joinpath('mdf-sa-ddi.db')
|
79
|
+
# if not os.path.exists(db_zip_path):
|
80
|
+
# self.__to_db__(db_path)
|
81
|
+
# else:
|
82
|
+
# ZipHelper().extract(
|
83
|
+
# input_path=str(HERE), output_path=str(HERE))
|
84
|
+
# conn = create_connection(db_path)
|
85
|
+
# self.drugs_df = select_all_drugs_as_dataframe(conn)
|
86
|
+
# self.ddis_df = select_all_events_as_dataframe(conn)
|
87
|
+
# # kwargs = {'index_path': str(HERE.joinpath('indexes'))}
|
88
|
+
# kwargs['index_path'] = str(HERE.joinpath('indexes'))
|
89
|
+
|
90
|
+
# self.index_path = kwargs.get('index_path')
|
91
|
+
|
92
|
+
dataset_name: str = "MDFSADDIDataset"
|
93
|
+
index_path: str = Field(default_factory=lambda: str(
|
94
|
+
pathlib.Path(__file__).resolve().parent.joinpath('indexes')))
|
95
|
+
# drugs_df: pd.DataFrame = Field(default_factory=pd.DataFrame)
|
96
|
+
# ddis_df: pd.DataFrame = Field(default_factory=pd.DataFrame)
|
97
|
+
drugs_df: Optional[pd.DataFrame] = None
|
98
|
+
ddis_df: Optional[pd.DataFrame] = None
|
99
|
+
|
100
|
+
chemical_property_columns: list[str] = Field(
|
101
|
+
default_factory=lambda: LIST_OF_CHEMICAL_PROPERTY_COLUMNS)
|
102
|
+
embedding_columns: list[str] = Field(default_factory=list)
|
103
|
+
ner_columns: list[str] = Field(default_factory=list)
|
104
|
+
ner_df: pd.DataFrame | None = None
|
105
|
+
tui_threshold: float | None = None
|
106
|
+
cui_threshold: float | None = None
|
107
|
+
entities_threshold: float | None = None
|
108
|
+
_ner_threshold: dict[str,Any] |None= None
|
109
|
+
|
110
|
+
# @model_validator
|
111
|
+
|
112
|
+
def validate_columns(self, values):
|
113
|
+
if not set(values['chemical_property_columns']).issubset(LIST_OF_CHEMICAL_PROPERTY_COLUMNS):
|
114
|
+
raise ValueError("Invalid chemical property columns")
|
115
|
+
if not set(values['ner_columns']).issubset(LIST_OF_NER_COLUMNS):
|
116
|
+
raise ValueError("Invalid NER columns")
|
117
|
+
return values
|
118
|
+
|
119
|
+
def __init__(self, **kwargs):
|
120
|
+
|
121
|
+
super().__init__(**kwargs)
|
122
|
+
|
123
|
+
# self.additional_config = kwargs.get('dataset_additional_config', {})
|
124
|
+
if self.additional_config:
|
125
|
+
ner = self.additional_config.get('ner', {})
|
126
|
+
ner_data_file = ner.get('data_file', None)
|
127
|
+
self._ner_threshold = ner.get('thresholds', None)
|
128
|
+
# if self.ner_threshold:
|
129
|
+
# for k, v in self.ner_threshold.items():
|
130
|
+
# kwargs[k] = v
|
131
|
+
|
132
|
+
self.ner_df = CTakesNER(df=None).load(
|
133
|
+
filename=ner_data_file) if ner_data_file else None
|
134
|
+
|
135
|
+
columns = kwargs['columns']
|
136
|
+
if columns:
|
137
|
+
chemical_property_columns = []
|
138
|
+
embedding_columns = []
|
139
|
+
ner_columns = []
|
140
|
+
for column in columns:
|
141
|
+
if column in LIST_OF_CHEMICAL_PROPERTY_COLUMNS:
|
142
|
+
chemical_property_columns.append(column)
|
143
|
+
elif column in LIST_OF_EMBEDDING_COLUMNS:
|
144
|
+
embedding_columns.append(column)
|
145
|
+
elif column in LIST_OF_NER_COLUMNS:
|
146
|
+
ner_columns.append(column)
|
147
|
+
else:
|
148
|
+
raise Exception(f"{column} is not related this dataset")
|
149
|
+
|
150
|
+
self.chemical_property_columns = chemical_property_columns
|
151
|
+
self.embedding_columns = embedding_columns
|
152
|
+
self.ner_columns = ner_columns
|
153
|
+
self.columns = [] # these variable is modified in prep method
|
154
|
+
|
155
|
+
|
156
|
+
db_zip_path = HERE.joinpath('mdf-sa-ddi.zip')
|
157
|
+
db_path = HERE.joinpath('mdf-sa-ddi.db')
|
158
|
+
if not os.path.exists(db_zip_path):
|
159
|
+
self.__to_db__(db_path)
|
160
|
+
else:
|
161
|
+
ZipHelper().extract(
|
162
|
+
input_path=str(HERE), output_path=str(HERE))
|
163
|
+
conn = create_connection(db_path.absolute().as_posix())
|
164
|
+
self.drugs_df = select_all_drugs_as_dataframe(conn)
|
165
|
+
self.ddis_df = select_all_events_as_dataframe(conn)
|
166
|
+
# kwargs = {'index_path': str(HERE.joinpath('indexes'))}
|
167
|
+
|
168
|
+
|
169
|
+
self.class_column = 'event_category'
|
170
|
+
|
171
|
+
self.__similarity_related_columns__ = []
|
172
|
+
self.__similarity_related_columns__.extend(
|
173
|
+
self.chemical_property_columns)
|
174
|
+
self.__similarity_related_columns__.extend(self.ner_columns)
|
175
|
+
logger.info(f'{self.dataset_name} is initialized')
|
176
|
+
|
177
|
+
def __to_db__(self, db_path):
|
178
|
+
conn = create_connection(db_path)
|
179
|
+
drugs_path = HERE.joinpath('drug_information_del_noDDIxiaoyu50.csv')
|
180
|
+
ddis_path = HERE.joinpath('df_extraction_cleanxiaoyu50.csv')
|
181
|
+
self.drugs_df = pd.read_csv(drugs_path)
|
182
|
+
self.ddis_df = pd.read_csv(ddis_path)
|
183
|
+
self.drugs_df.drop(columns="Unnamed: 0", inplace=True)
|
184
|
+
self.ddis_df.drop(columns="Unnamed: 0", inplace=True)
|
185
|
+
|
186
|
+
self.ddis_df.rename(
|
187
|
+
columns={"drugA": "name1", "drugB": "name2"}, inplace=True)
|
188
|
+
self.ddis_df['event_category'] = self.ddis_df['mechanism'] + \
|
189
|
+
' ' + self.ddis_df['action']
|
190
|
+
|
191
|
+
reverse_ddis_df = pd.DataFrame()
|
192
|
+
reverse_ddis_df['id1'] = self.ddis_df['id2']
|
193
|
+
reverse_ddis_df['name1'] = self.ddis_df['name2']
|
194
|
+
reverse_ddis_df['id2'] = self.ddis_df['id1']
|
195
|
+
reverse_ddis_df['name2'] = self.ddis_df['name1']
|
196
|
+
reverse_ddis_df['event_category'] = self.ddis_df['event_category']
|
197
|
+
|
198
|
+
self.ddis_df = pd.concat(
|
199
|
+
[self.ddis_df, reverse_ddis_df], ignore_index=True)
|
200
|
+
|
201
|
+
drug_name_id_pairs = {}
|
202
|
+
for idx, row in self.drugs_df.iterrows():
|
203
|
+
drug_name_id_pairs[row['name']] = row['id']
|
204
|
+
|
205
|
+
# id1,id2
|
206
|
+
|
207
|
+
def lambda_fnc1(column):
|
208
|
+
return drug_name_id_pairs[column]
|
209
|
+
# def lambda_fnc2(row):
|
210
|
+
# x = self.drugs_df[self.drugs_df['name'] == row['name2']]
|
211
|
+
# return x['id']
|
212
|
+
|
213
|
+
self.ddis_df['id1'] = self.ddis_df['name1'].apply(
|
214
|
+
lambda_fnc1) # , axis=1
|
215
|
+
self.ddis_df['id2'] = self.ddis_df['name2'].apply(
|
216
|
+
lambda_fnc1) # , axis=1
|
217
|
+
if conn:
|
218
|
+
self.drugs_df.to_sql('drug', conn, if_exists='replace', index=False)
|
219
|
+
self.ddis_df.to_sql('event', conn, if_exists='replace', index=False)
|
220
|
+
ZipHelper().zip_single_file(
|
221
|
+
file_path=db_path, output_path=HERE, zip_name='mdf-sa-ddi')
|
222
|
+
|
223
|
+
def prep(self):
|
224
|
+
# self.load_drugs_and_events()
|
225
|
+
if self.drugs_df is None or self.ddis_df is None:
|
226
|
+
raise Exception("There is no data")
|
227
|
+
|
228
|
+
drug_ids = self.drugs_df['id'].to_list()
|
229
|
+
|
230
|
+
filtered_df = self.drugs_df
|
231
|
+
combined_df = filtered_df.copy()
|
232
|
+
|
233
|
+
if self.ner_df is not None and not self.ner_df.empty:
|
234
|
+
filtered_ner_df = self.ner_df[self.ner_df['drugbank_id'].isin(
|
235
|
+
drug_ids)]
|
236
|
+
filtered_ner_df = self.ner_df.copy()
|
237
|
+
|
238
|
+
# TODO: eğer kullanılan veri setinde tui, cui veya entity bilgileri yoksa o veri setine bu sütunları eklemek için aşağısı gerekli
|
239
|
+
|
240
|
+
# idf_calc = IDF(filtered_ner_df, [f for f in filtered_ner_df.keys()])
|
241
|
+
idf_calc = IDF(filtered_ner_df, self.ner_columns)
|
242
|
+
idf_calc.calculate()
|
243
|
+
idf_scores_df = idf_calc.to_dataframe()
|
244
|
+
|
245
|
+
# for key in filtered_ner_df.keys():
|
246
|
+
for key in self.ner_columns:
|
247
|
+
threshold = self._ner_threshold.get(key, 0) if self._ner_threshold else 0
|
248
|
+
# threshold = 0
|
249
|
+
# if key.startswith('tui'):
|
250
|
+
# threshold = self.tui_threshold
|
251
|
+
# if key.startswith('cui'):
|
252
|
+
# threshold = self.cui_threshold
|
253
|
+
# if key.startswith('entities'):
|
254
|
+
# threshold = self.entities_threshold
|
255
|
+
combined_df[key] = filtered_ner_df[key]
|
256
|
+
valid_codes = idf_scores_df[idf_scores_df[key]
|
257
|
+
> threshold].index
|
258
|
+
|
259
|
+
# print(f'{key}: valid code size = {len(valid_codes)}')
|
260
|
+
combined_df[key] = combined_df[key].apply(lambda items:
|
261
|
+
[item for item in items if item in valid_codes])
|
262
|
+
|
263
|
+
moved_columns = ['id']
|
264
|
+
moved_columns.extend(self.__similarity_related_columns__)
|
265
|
+
chemical_properties_df = combined_df[moved_columns]
|
266
|
+
|
267
|
+
chemical_properties_df = chemical_properties_df.fillna("").apply(list)
|
268
|
+
|
269
|
+
# generate vectors dictionary içinde ndarray dönecek
|
270
|
+
generated_vectors = generate_vectors(
|
271
|
+
chemical_properties_df, self.__similarity_related_columns__)
|
272
|
+
|
273
|
+
# TODO if necessary
|
274
|
+
similarity_matrices = generate_sim_matrices_new(
|
275
|
+
chemical_properties_df, generated_vectors, self.__similarity_related_columns__, key_column="id")
|
276
|
+
|
277
|
+
event_categories = self.ddis_df['event_category']
|
278
|
+
labels = event_categories.tolist()
|
279
|
+
lb = LabelBinarizer()
|
280
|
+
lb.fit(labels)
|
281
|
+
classes = lb.transform(labels)
|
282
|
+
|
283
|
+
def similarity_lambda_fnc(row, value):
|
284
|
+
if row['id1'] in value:
|
285
|
+
return value[row['id1']]
|
286
|
+
|
287
|
+
def lambda_fnc(row: pd.Series, value) -> Optional[np.float16]:
|
288
|
+
if row['id1'] in value and row['id2'] in value:
|
289
|
+
return np.float16(np.hstack(
|
290
|
+
(value[row['id1']], value[row['id2']])))
|
291
|
+
return None
|
292
|
+
# return np.hstack(
|
293
|
+
# (value[row['id1']], value[row['id2']]), dtype=np.float16)
|
294
|
+
|
295
|
+
def x_fnc(row, embeddings_after_pooling):
|
296
|
+
if row['id1'] in embeddings_after_pooling:
|
297
|
+
v1 = embeddings_after_pooling[row['id1']]
|
298
|
+
else:
|
299
|
+
v1 = np.zeros(self.embedding_size)
|
300
|
+
if row['id2'] in embeddings_after_pooling:
|
301
|
+
v2 = embeddings_after_pooling[row['id2']]
|
302
|
+
else:
|
303
|
+
v2 = np.zeros(self.embedding_size)
|
304
|
+
return np.float16(np.hstack(
|
305
|
+
(v1, v2)))
|
306
|
+
|
307
|
+
for key, value in similarity_matrices.items():
|
308
|
+
|
309
|
+
print(f'sim matrix: {key}')
|
310
|
+
self.ddis_df[key] = self.ddis_df.apply(
|
311
|
+
lambda_fnc, args=(value,), axis=1)
|
312
|
+
self.columns.append(key)
|
313
|
+
print(self.ddis_df[key].head())
|
314
|
+
if isinstance(self, TextDatasetMixin):
|
315
|
+
if self.embedding_dict is not None:
|
316
|
+
for embedding_column in self.embedding_columns:
|
317
|
+
print(f"concat {embedding_column} embeddings")
|
318
|
+
embeddings_after_pooling = {k: self.pooling_strategy.apply(
|
319
|
+
v) for k, v in self.embedding_dict[embedding_column].items()}
|
320
|
+
# column_embeddings_dict = embedding_values[embedding_column]
|
321
|
+
self.ddis_df[embedding_column+'_embedding'] = self.ddis_df.apply(
|
322
|
+
x_fnc, args=(embeddings_after_pooling,), axis=1)
|
323
|
+
self.columns.append(embedding_column+'_embedding')
|
324
|
+
|
325
|
+
dataframe = self.ddis_df.copy()
|
326
|
+
if not isinstance(classes, (list, pd.Series, np.ndarray)):
|
327
|
+
raise TypeError(
|
328
|
+
"classes must be an iterable (list, Series, or ndarray)")
|
329
|
+
|
330
|
+
if len(classes) != len(dataframe):
|
331
|
+
raise ValueError(
|
332
|
+
"Length of classes must match the number of rows in the DataFrame")
|
333
|
+
|
334
|
+
dataframe[self.class_column] = list(classes)
|
335
|
+
self.set_dataframe(dataframe)
|
336
|
+
|
337
|
+
|
338
|
+
def select_all_drugs(conn):
|
339
|
+
cur = conn.cursor()
|
340
|
+
cur.execute(
|
341
|
+
'''select "index", id, name, target, enzyme, smile from drug''')
|
342
|
+
rows = cur.fetchall()
|
343
|
+
return rows
|
344
|
+
|
345
|
+
|
346
|
+
def select_all_drugs_as_dataframe(conn):
|
347
|
+
headers = ['index', 'id', 'name', 'target', 'enzyme', 'smile']
|
348
|
+
rows = select_all_drugs(conn)
|
349
|
+
df = pd.DataFrame(columns=headers, data=rows)
|
350
|
+
df['enzyme'] = df['enzyme'].apply(lambda x: x.split('|'))
|
351
|
+
df['target'] = df['target'].apply(lambda x: x.split('|'))
|
352
|
+
df['smile'] = df['smile'].apply(lambda x: x.split('|'))
|
353
|
+
return df
|
354
|
+
|
355
|
+
|
356
|
+
def select_all_events(conn):
|
357
|
+
"""
|
358
|
+
Query all rows in the event table
|
359
|
+
:param conn: the Connection object
|
360
|
+
:return:
|
361
|
+
"""
|
362
|
+
cur = conn.cursor()
|
363
|
+
cur.execute('''
|
364
|
+
select event."index", id1, name1, id2, name2, mechanism, action, event_category from event
|
365
|
+
''')
|
366
|
+
|
367
|
+
rows = cur.fetchall()
|
368
|
+
return rows
|
369
|
+
|
370
|
+
|
371
|
+
def select_all_events_as_dataframe(conn):
|
372
|
+
headers = ["index", "id1", "name1", "id2",
|
373
|
+
"name2", "mechanism", "action", "event_category"]
|
374
|
+
rows = select_all_events(conn)
|
375
|
+
return pd.DataFrame(columns=headers, data=rows)
|
@@ -2,4 +2,5 @@ from .ml_helper import MultiModalRunner
|
|
2
2
|
from .model_wrapper import ModelWrapper,Result
|
3
3
|
from .tensorflow_wrapper import TFModelWrapper
|
4
4
|
from .pytorch_wrapper import PTModelWrapper
|
5
|
-
from .evaluation_helper import evaluate
|
5
|
+
from .evaluation_helper import evaluate
|
6
|
+
from .tracking_service import TrackingService
|
@@ -1,23 +1,9 @@
|
|
1
|
-
from typing import Callable, Dict, List, Tuple
|
2
|
-
from matplotlib import pyplot as plt
|
3
1
|
from ddi_fw.ml.model_wrapper import Result
|
4
2
|
from ddi_fw.ml.pytorch_wrapper import PTModelWrapper
|
5
3
|
from ddi_fw.ml.tensorflow_wrapper import TFModelWrapper
|
6
4
|
from ddi_fw.utils.package_helper import get_import
|
7
|
-
import tensorflow as tf
|
8
|
-
from tensorflow.python import keras
|
9
|
-
from tensorflow.python.keras import Model, Sequential
|
10
|
-
from tensorflow.python.keras.layers import Dense, Dropout, Input, Activation
|
11
|
-
from tensorflow.python.keras.callbacks import EarlyStopping
|
12
|
-
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
|
13
5
|
import numpy as np
|
14
|
-
|
15
|
-
import mlflow
|
16
|
-
from mlflow.utils.autologging_utils import batch_metrics_logger
|
17
|
-
import time
|
18
|
-
|
19
|
-
from mlflow.models import infer_signature
|
20
|
-
from ddi_fw.ml.evaluation_helper import Metrics, evaluate
|
6
|
+
from ddi_fw.ml.evaluation_helper import evaluate
|
21
7
|
|
22
8
|
# import tf2onnx
|
23
9
|
# import onnx
|
@@ -32,16 +18,16 @@ import ddi_fw.utils as utils
|
|
32
18
|
|
33
19
|
class MultiModalRunner:
|
34
20
|
# todo model related parameters to config
|
35
|
-
def __init__(self, library, multi_modal, default_model,
|
21
|
+
def __init__(self, library, multi_modal, default_model, tracking_service):
|
36
22
|
self.library = library
|
37
23
|
self.multi_modal = multi_modal
|
38
24
|
self.default_model = default_model
|
39
|
-
self.
|
25
|
+
self.tracking_service = tracking_service
|
40
26
|
self.result = Result()
|
41
27
|
|
42
|
-
def _mlflow_(self, func: Callable):
|
43
|
-
|
44
|
-
|
28
|
+
# def _mlflow_(self, func: Callable):
|
29
|
+
# if self.use_mlflow:
|
30
|
+
# func()
|
45
31
|
|
46
32
|
def set_data(self, items, train_idx_arr, val_idx_arr, y_test_label):
|
47
33
|
self.items = items
|
@@ -74,7 +60,7 @@ class MultiModalRunner:
|
|
74
60
|
kwargs = m.get('params')
|
75
61
|
T = self.__create_model(self.library)
|
76
62
|
single_modal = T(self.date, name, model_type,
|
77
|
-
|
63
|
+
tracking_service=self.tracking_service, **kwargs)
|
78
64
|
|
79
65
|
if input is not None and inputs is not None:
|
80
66
|
raise Exception("input and inputs should not be used together")
|
@@ -110,7 +96,7 @@ class MultiModalRunner:
|
|
110
96
|
name = item[0]
|
111
97
|
T = self.__create_model(self.library)
|
112
98
|
single_modal = T(self.date, name, model_type,
|
113
|
-
|
99
|
+
tracking_service=self.tracking_service, **kwargs)
|
114
100
|
single_modal.set_data(
|
115
101
|
self.train_idx_arr, self.val_idx_arr, item[1], item[2], item[3], item[4])
|
116
102
|
|
@@ -130,9 +116,12 @@ class MultiModalRunner:
|
|
130
116
|
combinations = []
|
131
117
|
for i in range(2, len(l) + 1):
|
132
118
|
combinations.extend(list(itertools.combinations(l, i))) # all
|
133
|
-
|
134
|
-
|
135
|
-
|
119
|
+
|
120
|
+
def _f():
|
121
|
+
self.__predict(single_results)
|
122
|
+
|
123
|
+
if self.tracking_service:
|
124
|
+
self.tracking_service.run(run_name=self.prefix, description="***", func = _f , nested_run=False)
|
136
125
|
else:
|
137
126
|
self.__predict(single_results)
|
138
127
|
if combinations:
|
@@ -143,10 +132,17 @@ class MultiModalRunner:
|
|
143
132
|
def evaluate_combinations(self, single_results, combinations):
|
144
133
|
for combination in combinations:
|
145
134
|
combination_descriptor = '-'.join(combination)
|
146
|
-
if self.
|
147
|
-
|
135
|
+
if self.tracking_service:
|
136
|
+
def evaluate_combination(artifact_uri=None):
|
148
137
|
self.__evaluate_combinations(
|
149
|
-
single_results, combination, combination_descriptor,
|
138
|
+
single_results, combination, combination_descriptor, artifact_uri
|
139
|
+
)
|
140
|
+
|
141
|
+
self.tracking_service.run(run_name=combination_descriptor, description="***", nested_run=True, func=evaluate_combination)
|
142
|
+
|
143
|
+
# with mlflow.start_run(run_name=combination_descriptor, description="***", nested=True) as combination_run:
|
144
|
+
# self.__evaluate_combinations(
|
145
|
+
# single_results, combination, combination_descriptor, combination_run.info.artifact_uri)
|
150
146
|
else:
|
151
147
|
self.__evaluate_combinations(
|
152
148
|
single_results, combination, combination_descriptor, None)
|
@@ -159,8 +155,8 @@ class MultiModalRunner:
|
|
159
155
|
prediction = utils.to_one_hot_encode(prediction)
|
160
156
|
logs, metrics = evaluate(
|
161
157
|
actual=self.y_test_label, pred=prediction, info=combination_descriptor)
|
162
|
-
if self.
|
163
|
-
|
158
|
+
if self.tracking_service:
|
159
|
+
self.tracking_service.log_metrics(logs)
|
164
160
|
metrics.format_float()
|
165
161
|
# TODO path bulunamadı hatası aldık
|
166
162
|
if artifact_uri:
|