ddi-fw 0.0.214__py3-none-any.whl → 0.0.215__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -20,160 +20,160 @@ def create_connection(db_file=r"./event.db"):
  return conn


- def select_all_drugs(conn):
- cur = conn.cursor()
- cur.execute(
- '''select "index", id, name, target, enzyme, pathway, smile from drug''')
- rows = cur.fetchall()
- return rows
-
-
- def select_all_drugs_as_dataframe(conn):
- headers = ['index','id', 'name', 'target', 'enzyme', 'pathway', 'smile']
- rows = select_all_drugs(conn)
- df = pd.DataFrame(columns=headers, data=rows)
- df['enzyme'] = df['enzyme'].apply(lambda x: x.split('|'))
- df['target'] = df['target'].apply(lambda x: x.split('|'))
- df['pathway'] = df['pathway'].apply(lambda x: x.split('|'))
- df['smile'] = df['smile'].apply(lambda x: x.split('|'))
- return df
-
-
- def select_all_events(conn):
- """
- Query all rows in the event table
- :param conn: the Connection object
- :return:
- """
- cur = conn.cursor()
- cur.execute("select * from event")
-
- rows = cur.fetchall()
- return rows
-
-
- def select_all_events_as_dataframe(conn):
- headers = ["index", "id1", "name1", "id2", "name2", "event_category"]
- rows = select_all_events(conn)
- return pd.DataFrame(columns=headers, data=rows)
-
-
- def select_events_with_category(conn):
- sql = '''select id1, name1, id2, name2, mechanism || ' ' ||action from event ev
- join extraction ex
- on ev.name1 = ex.drugA and ev.name2 = ex.drugB
- union
- select id1, name1, id2, name2, mechanism || ' ' ||action from event ev
- join extraction ex
- on ev.name1 = ex.drugB and ev.name2 = ex.drugA
- '''
- cur = conn.cursor()
- cur.execute(sql)
-
- rows = cur.fetchall()
-
- headers = ['id1', 'name1', 'id2', 'name2', 'event_category']
- return pd.DataFrame(columns=headers, data=rows)
-
-
- def select_all_interactions_tuple_as_dataframe(conn):
- cur = conn.cursor()
- cur.execute("select id1, id2 from event")
- rows = cur.fetchall()
- headers = ['id1', 'id2']
-
- return pd.DataFrame(columns=headers, data=rows)
-
-
- def select_ddi_pairs(conn):
- cur = conn.cursor()
- cur.execute('''
- select d1.[index] as Drug1Index, d2.[index] as Drug2Index, 1 from event e
- join drug d1 on e.id1 = d1.id
- join drug d2 on e.id2 = d2.id
- ''')
- rows = cur.fetchall()
- return rows
-
-
- def select_ddi_pairs_as_dataframe(conn):
- headers = ["Drug1Index", "Drug2Index", "Interaction"]
- rows = select_ddi_pairs(conn)
- return pd.DataFrame(columns=headers, data=rows)
-
-
- def get_interactions(conn):
- cur = conn.cursor()
- cur.execute('''
- select
- drug_1_id,
- drug_1,
- drug_2_id,
- drug_2,
- mechanism_action,
- interaction,
- masked_interaction
- from _Interactions
- ''')
-
- rows = cur.fetchall()
-
- headers = ['id1', 'name1', 'id2', 'name2',
- 'event_category', 'interaction', 'masked_interaction']
- df = pd.DataFrame(columns=headers, data=rows)
- return df
-
-
- def get_extended_version(conn):
- cur = conn.cursor()
- cur.execute('''
- select
- _Drugs."index",
- drugbank_id,
- _Drugs.name,
- description,
- synthesis_reference,
- indication,
- pharmacodynamics,
- mechanism_of_action,
- toxicity,
- metabolism,
- absorption,
- half_life,
- protein_binding,
- route_of_elimination,
- volume_of_distribution,
- clearance,
- smiles,
- smiles_morgan_fingerprint,
- enzymes_polypeptides,
- targets_polypeptides
+ # def select_all_drugs(conn):
+ # cur = conn.cursor()
+ # cur.execute(
+ # '''select "index", id, name, target, enzyme, pathway, smile from drug''')
+ # rows = cur.fetchall()
+ # return rows
+
+
+ # def select_all_drugs_as_dataframe(conn):
+ # headers = ['index','id', 'name', 'target', 'enzyme', 'pathway', 'smile']
+ # rows = select_all_drugs(conn)
+ # df = pd.DataFrame(columns=headers, data=rows)
+ # df['enzyme'] = df['enzyme'].apply(lambda x: x.split('|'))
+ # df['target'] = df['target'].apply(lambda x: x.split('|'))
+ # df['pathway'] = df['pathway'].apply(lambda x: x.split('|'))
+ # df['smile'] = df['smile'].apply(lambda x: x.split('|'))
+ # return df
+
+
+ # def select_all_events(conn):
+ # """
+ # Query all rows in the event table
+ # :param conn: the Connection object
+ # :return:
+ # """
+ # cur = conn.cursor()
+ # cur.execute("select * from event")
+
+ # rows = cur.fetchall()
+ # return rows
+
+
+ # def select_all_events_as_dataframe(conn):
+ # headers = ["index", "id1", "name1", "id2", "name2", "event_category"]
+ # rows = select_all_events(conn)
+ # return pd.DataFrame(columns=headers, data=rows)
+
+
+ # def select_events_with_category(conn):
+ # sql = '''select id1, name1, id2, name2, mechanism || ' ' ||action from event ev
+ # join extraction ex
+ # on ev.name1 = ex.drugA and ev.name2 = ex.drugB
+ # union
+ # select id1, name1, id2, name2, mechanism || ' ' ||action from event ev
+ # join extraction ex
+ # on ev.name1 = ex.drugB and ev.name2 = ex.drugA
+ # '''
+ # cur = conn.cursor()
+ # cur.execute(sql)
+
+ # rows = cur.fetchall()
+
+ # headers = ['id1', 'name1', 'id2', 'name2', 'event_category']
+ # return pd.DataFrame(columns=headers, data=rows)
+
+
+ # def select_all_interactions_tuple_as_dataframe(conn):
+ # cur = conn.cursor()
+ # cur.execute("select id1, id2 from event")
+ # rows = cur.fetchall()
+ # headers = ['id1', 'id2']
+
+ # return pd.DataFrame(columns=headers, data=rows)
+
+
+ # def select_ddi_pairs(conn):
+ # cur = conn.cursor()
+ # cur.execute('''
+ # select d1.[index] as Drug1Index, d2.[index] as Drug2Index, 1 from event e
+ # join drug d1 on e.id1 = d1.id
+ # join drug d2 on e.id2 = d2.id
+ # ''')
+ # rows = cur.fetchall()
+ # return rows
+
+
+ # def select_ddi_pairs_as_dataframe(conn):
+ # headers = ["Drug1Index", "Drug2Index", "Interaction"]
+ # rows = select_ddi_pairs(conn)
+ # return pd.DataFrame(columns=headers, data=rows)
+
+
+ # def get_interactions(conn):
+ # cur = conn.cursor()
+ # cur.execute('''
+ # select
+ # drug_1_id,
+ # drug_1,
+ # drug_2_id,
+ # drug_2,
+ # mechanism_action,
+ # interaction,
+ # masked_interaction
+ # from _Interactions
+ # ''')
+
+ # rows = cur.fetchall()
+
+ # headers = ['id1', 'name1', 'id2', 'name2',
+ # 'event_category', 'interaction', 'masked_interaction']
+ # df = pd.DataFrame(columns=headers, data=rows)
+ # return df
+
+
+ # def get_extended_version(conn):
+ # cur = conn.cursor()
+ # cur.execute('''
+ # select
+ # _Drugs."index",
+ # drugbank_id,
+ # _Drugs.name,
+ # description,
+ # synthesis_reference,
+ # indication,
+ # pharmacodynamics,
+ # mechanism_of_action,
+ # toxicity,
+ # metabolism,
+ # absorption,
+ # half_life,
+ # protein_binding,
+ # route_of_elimination,
+ # volume_of_distribution,
+ # clearance,
+ # smiles,
+ # smiles_morgan_fingerprint,
+ # enzymes_polypeptides,
+ # targets_polypeptides

- from drug
- join _Drugs on drug.id = _Drugs.drugbank_id
- where
- targets_polypeptides is not null and
- enzymes_polypeptides is not null and
- smiles_morgan_fingerprint is not null
- ''')
- # pathway is absent
-
- rows = cur.fetchall()
- headers = ['index', 'id', 'name', 'description', 'synthesis_reference', 'indication', 'pharmacodynamics', 'mechanism_of_action', 'toxicity', 'metabolism', 'absorption', 'half_life',
- 'protein_binding', 'route_of_elimination', 'volume_of_distribution', 'clearance', 'smiles_notation', 'smile', 'enzyme', 'target']
- df = pd.DataFrame(columns=headers, data=rows)
- df['smile'] = df['smile'].apply(lambda x:
- np.fromstring(
- x.replace(
- '\n', '')
- .replace('[', '')
- .replace(']', '')
- .replace(' ', ' '), sep=','))
- df['enzyme'] = df['enzyme'].apply(
- lambda x: x.split('|'))
- df['target'] = df['target'].apply(
- lambda x: x.split('|'))
- return df
+ # from drug
+ # join _Drugs on drug.id = _Drugs.drugbank_id
+ # where
+ # targets_polypeptides is not null and
+ # enzymes_polypeptides is not null and
+ # smiles_morgan_fingerprint is not null
+ # ''')
+ # # pathway is absent
+
+ # rows = cur.fetchall()
+ # headers = ['index', 'id', 'name', 'description', 'synthesis_reference', 'indication', 'pharmacodynamics', 'mechanism_of_action', 'toxicity', 'metabolism', 'absorption', 'half_life',
+ # 'protein_binding', 'route_of_elimination', 'volume_of_distribution', 'clearance', 'smiles_notation', 'smile', 'enzyme', 'target']
+ # df = pd.DataFrame(columns=headers, data=rows)
+ # df['smile'] = df['smile'].apply(lambda x:
+ # np.fromstring(
+ # x.replace(
+ # '\n', '')
+ # .replace('[', '')
+ # .replace(']', '')
+ # .replace(' ', ' '), sep=','))
+ # df['enzyme'] = df['enzyme'].apply(
+ # lambda x: x.split('|'))
+ # df['target'] = df['target'].apply(
+ # lambda x: x.split('|'))
+ # return df


  # SELECT
@@ -190,15 +190,15 @@ def get_extended_version(conn):
  # where LENGTH(masked_interaction) = LENGTH(REPLACE(masked_interaction, 'DRUG', ''))
  # or LENGTH(masked_interaction) = LENGTH(REPLACE(masked_interaction, 'DRUG', '')) + 4

- if __name__ == "__main__":
- conn = create_connection(r"./event-extended.db")
- extended_version_df = get_extended_version(conn)
+ # if __name__ == "__main__":
+ # conn = create_connection(r"./event-extended.db")
+ # extended_version_df = get_extended_version(conn)

- df = select_all_events_as_dataframe(conn)
- print(df.head())
+ # df = select_all_events_as_dataframe(conn)
+ # print(df.head())

- events_with_category_df = select_events_with_category(conn)
- print(events_with_category_df.head())
+ # events_with_category_df = select_events_with_category(conn)
+ # print(events_with_category_df.head())

- u = events_with_category_df['event_category'].unique()
- print(len(u))
+ # u = events_with_category_df['event_category'].unique()
+ # print(len(u))
@@ -96,10 +96,10 @@ class DDIMDLDataset(BaseDataset,TextDatasetMixin):
  logger.info(f'{self.dataset_name} is initialized')

  def load_drugs_and_events(self):
- self.drugs_df = self.__select_all_drugs_as_dataframe__()
- self.ddis_df = self.__select_all_events__()
+ self.drugs_df = self.__select_all_drugs_as_dataframe()
+ self.ddis_df = self.__select_all_events()

- def __select_all_drugs_as_dataframe__(self):
+ def __select_all_drugs_as_dataframe(self):
  headers = ['index', 'id', 'name',
  'target', 'enzyme', 'pathway', 'smile']
  if self._conn is None:
@@ -117,7 +117,7 @@ class DDIMDLDataset(BaseDataset,TextDatasetMixin):

  return df

- def __select_all_events__(self):
+ def __select_all_events(self):
  if self._conn is None:
  raise Exception("There is no connection")
  cur = self._conn.cursor()
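A note on the method renames in this hunk and in the DDIMDLDatasetV2 hunks below: `__select_all_events__` is a dunder-style name, which Python never mangles and which convention reserves for the interpreter's special methods; `__select_all_events`, with leading underscores only, is name-mangled to `_DDIMDLDataset__select_all_events` and so behaves as a class-private method. Calls inside the class work identically under both spellings, which is why the diff can rename definition and call sites in lockstep. A minimal sketch of the difference (`Demo` and `__helper` are illustrative names, not from the package):

    class Demo:
        def __helper(self):      # two leading underscores only: mangled to _Demo__helper
            return "private"

        def __helper__(self):    # dunder-style: not mangled, reserved by convention
            return "special-looking"

    d = Demo()
    assert d.__helper__() == "special-looking"
    assert d._Demo__helper() == "private"  # externally reachable only via the mangled name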
@@ -221,16 +221,16 @@ class DDIMDLDataset(BaseDataset,TextDatasetMixin):
  lambda_fnc, args=(value,), axis=1)
  self.columns.append(key)
  print(self.ddis_df[key].head())
-
- if self.embedding_dict is not None:
- for embedding_column in self.embedding_columns:
- print(f"concat {embedding_column} embeddings")
- embeddings_after_pooling = {k: self.pooling_strategy.apply(
- v) for k, v in self.embedding_dict[embedding_column].items()}
- # column_embeddings_dict = embedding_values[embedding_column]
- self.ddis_df[embedding_column+'_embedding'] = self.ddis_df.apply(
- x_fnc, args=(embeddings_after_pooling,), axis=1)
- self.columns.append(embedding_column+'_embedding')
+ if isinstance(self, TextDatasetMixin):
+ if self.embedding_dict is not None:
+ for embedding_column in self.embedding_columns:
+ print(f"concat {embedding_column} embeddings")
+ embeddings_after_pooling = {k: self.pooling_strategy.apply(
+ v) for k, v in self.embedding_dict[embedding_column].items()}
+ # column_embeddings_dict = embedding_values[embedding_column]
+ self.ddis_df[embedding_column+'_embedding'] = self.ddis_df.apply(
+ x_fnc, args=(embeddings_after_pooling,), axis=1)
+ self.columns.append(embedding_column+'_embedding')

  dataframe = self.ddis_df.copy()
  if not isinstance(classes, (list, pd.Series, np.ndarray)):
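The hunk above wraps the embedding concatenation in an `isinstance(self, TextDatasetMixin)` check, so the block only executes when the instance actually carries the mixin's embedding attributes. A minimal sketch of the guard pattern, with hypothetical stand-in classes (not the package's real definitions):

    class TextMixin:
        """Stand-in mixin: subclasses that carry embeddings set embedding_dict."""
        embedding_dict = None

    class Base:
        def prep(self):
            # Embedding logic runs only for instances that mix in text support.
            if isinstance(self, TextMixin) and self.embedding_dict is not None:
                return "concat embeddings"
            return "skip"

    class TextDataset(Base, TextMixin):
        def __init__(self):
            self.embedding_dict = {"description": [0.1, 0.2]}

    assert Base().prep() == "skip"
    assert TextDataset().prep() == "concat embeddings"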
@@ -90,8 +90,8 @@ class DDIMDLDatasetV2(BaseDataset):
  db = HERE.joinpath('data/event.db')
  conn = create_connection(db)
  print("db prep")
- self.drugs_df = self.__select_all_drugs_as_dataframe__(conn)
- self.ddis_df = self.__select_all_events__(conn)
+ self.drugs_df = self.__select_all_drugs_as_dataframe(conn)
+ self.ddis_df = self.__select_all_events(conn)
  print("db bitti")
  self.index_path = kwargs.get('index_path')

@@ -121,7 +121,7 @@ class DDIMDLDatasetV2(BaseDataset):
  # print(self.ddis_df[key].head())
  # print("init finished")

- def __select_all_drugs_as_dataframe__(self, conn):
+ def __select_all_drugs_as_dataframe(self, conn):
  headers = ['index', 'id', 'name',
  'target', 'enzyme', 'pathway', 'smile']
  cur = conn.cursor()
@@ -137,7 +137,7 @@ class DDIMDLDatasetV2(BaseDataset):

  return df

- def __select_all_events__(self, conn):
+ def __select_all_events(self, conn):
  """
  Query all rows in the event table
  :param conn: the Connection object
@@ -9,7 +9,6 @@ from ddi_fw.utils import ZipHelper
  from .. import BaseDataset
  from ddi_fw.langchain.embeddings import PoolingStrategy
  from ..db_utils import create_connection
- # from ..db_utils import create_connection, select_all_drugs_as_dataframe, select_events_with_category

  HERE = pathlib.Path(__file__).resolve().parent
  list_of_embedding_columns = ['all_text', 'description',
@@ -120,6 +120,8 @@ class MultiPipeline():
  columns = config.get("columns")
  ner_data_file = config.get("ner_data_file")
  ner_threshold = config.get("ner_threshold")
+ ner_min_threshold_dict = config.get("ner_min_threshold_dict")
+ ner_max_threshold_dict = config.get("ner_max_threshold_dict")
  column_embedding_configs = config.get("column_embedding_configs")
  vector_db_persist_directory = config.get("vector_db_persist_directory")
  vector_db_collection_name = config.get("vector_db_collection_name")
@@ -170,10 +172,14 @@ class MultiPipeline():
  experiment_tags=experiment_tags,
  tracking_uri=tracking_uri,
  dataset_type=dataset_type,
+ dataset_splitter_type=dataset_splitter_type,
  umls_code_types = None,
  text_types = None,
- columns=['tui', 'cui', 'entities'],
+ min_threshold_dict=ner_min_threshold_dict,
+ max_threshold_dict=ner_max_threshold_dict,
+ columns=columns,
  ner_data_file=ner_data_file,
+ default_model=default_model,
  multi_modal= multi_modal
  )

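With this hunk, the NER column list and per-column threshold bounds for "ner_search" experiments come from the experiment config (`columns`, `ner_min_threshold_dict`, `ner_max_threshold_dict`) instead of being hard-coded. A hypothetical config fragment matching the `config.get(...)` keys above, written as a Python dict (all values illustrative):

    experiment = {
        "type": "ner_search",
        "columns": ["tui", "cui", "entities"],
        "ner_data_file": "ner/ctakes_output.pkl",  # made-up path
        "ner_min_threshold_dict": {"tui": 0.0, "cui": 0.0, "entities": 0.0},
        "ner_max_threshold_dict": {"tui": 3.0, "cui": 3.0, "entities": 3.0},
    }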
@@ -0,0 +1,231 @@
+ import json
+ from typing import Optional
+ from ddi_fw.pipeline.pipeline import Pipeline
+ from ddi_fw.pipeline.ner_pipeline import NerParameterSearch
+ import importlib
+
+
+ def load_config(file_path):
+ with open(file_path, 'r') as file:
+ config = json.load(file)
+ return config
+
+
+ def get_import(full_path_of_import):
+ """Dynamically imports an object from a module given its full path.
+
+ Args:
+ full_path_of_import (str): The full path of the import (e.g., 'module.submodule.ClassName').
+
+ Returns:
+ object: The imported object.
+
+ Raises:
+ ImportError: If the module cannot be imported.
+ AttributeError: If the attribute does not exist in the module.
+ """
+ if not full_path_of_import:
+ raise ValueError("The import path cannot be empty.")
+
+ parts = full_path_of_import.split('.')
+ import_name = parts[-1]
+ module_name = ".".join(parts[:-1]) if len(parts) > 1 else ""
+
+ try:
+ module = importlib.import_module(module_name)
+ return getattr(module, import_name)
+ except ModuleNotFoundError as e:
+ raise ImportError(f"Module '{module_name}' could not be found.") from e
+ except AttributeError as e:
+ raise AttributeError(
+ f"'{module_name}' has no attribute '{import_name}'") from e
+
+
+ class MultiPipeline():
+ # def __init__(self, experiments_config_file, experiments_config):
+ # if experiments_config_file is None and experiments_config is None:
+ # raise ValueError("Either experiments_config_file or experiments_config must be provided.")
+ # if experiments_config_file is not None and experiments_config is not None:
+ # raise ValueError("Only one of experiments_config_file or experiments_config should be provided.")
+ # if experiments_config_file is not None:
+ # self.experiments_config = load_config(experiments_config_file)
+ # else:
+ # self.experiments_config = experiments_config
+ # self.items = []
+ # self.pipeline_resuts = dict()
+
+ def __init__(self, experiments_config_file: Optional[str] = None, experiments_config: Optional[dict] = None):
+ """
+ Initialize the MultiPipeline.
+
+ Args:
+ experiments_config_file (str, optional): Path to the experiments configuration file.
+ experiments_config (dict, optional): Dictionary containing the experiments configuration.
+
+ Raises:
+ ValueError: If neither or both of the parameters are provided.
+ """
+ self.experiments_config = self._validate_and_load_config(experiments_config_file, experiments_config)
+ self.items = []
+ # self.pipeline_results = {}
+ self.pipeline_resuts = dict()
+
+ def _validate_and_load_config(self, experiments_config_file: Optional[str], experiments_config: Optional[dict]) -> dict:
+ """
+ Validate and load the experiments configuration.
+
+ Args:
+ experiments_config_file (str, optional): Path to the experiments configuration file.
+ experiments_config (dict, optional): Dictionary containing the experiments configuration.
+
+ Returns:
+ dict: The loaded experiments configuration.
+
+ Raises:
+ ValueError: If neither or both of the parameters are provided.
+ """
+ if experiments_config_file is None and experiments_config is None:
+ raise ValueError("Either 'experiments_config_file' or 'experiments_config' must be provided.")
+ if experiments_config_file is not None and experiments_config is not None:
+ raise ValueError("Only one of 'experiments_config_file' or 'experiments_config' should be provided.")
+
+ if experiments_config_file is not None:
+ try:
+ config = load_config(experiments_config_file)
+ except FileNotFoundError:
+ raise FileNotFoundError(f"Configuration file '{experiments_config_file}' not found.")
+ else:
+ config = experiments_config
+ if config is None:
+ raise ValueError("Configuration cannot be None.")
+ if not isinstance(config, dict):
+ raise ValueError("Configuration must be a dictionary.")
+ # if "experiments" not in config:
+ # raise ValueError("Configuration must contain 'experiments' key.")
+ return config
+
+ def __create_pipeline(self, config):
+ type = config.get("type")
+ library = config.get("library")
+ experiment_name = config.get("experiment_name")
+ experiment_description = config.get("experiment_description")
+ experiment_tags = config.get("experiment_tags")
+
+ # Tracking configuration
+ tracking_config = config.get("tracking_config", {})
+ tracking_library = tracking_config.get("library")
+ use_tracking = tracking_config.get("use_tracking", False)
+ tracking_params = tracking_config.get("params", {}).get(tracking_library, {})
+
+ # tracking_uri = config.get("tracking_uri")
+ # artifact_location = config.get("artifact_location")
+
+ # Dataset configuration
+ dataset_config = config.get("dataset", {})
+ dataset_type = get_import(dataset_config.get("dataset_type"))
+ dataset_splitter_type = get_import(dataset_config.get("dataset_splitter_type"))
+ columns = dataset_config.get("columns", [])
+ additional_config = dataset_config.get("additional_config", {})
+
+ # Vector database configuration
+ vector_database = config.get("vector_databases", {})
+ vector_db_persist_directory = None
+ vector_db_collection_name = None
+ embedding_pooling_strategy = None
+ if vector_database:
+ vector_db_persist_directory = vector_database.get("vector_db_persist_directory")
+ vector_db_collection_name = vector_database.get("vector_db_collection_name")
+ embedding_pooling_strategy = get_import(vector_database.get("embedding_pooling_strategy"))
+ column_embedding_configs = vector_database.get("column_embedding_configs")
+
+ # Combination strategy
+ combination_strategy_config = config.get("combination_strategy", {})
+ combination_type = get_import(combination_strategy_config.get("type")) if combination_strategy_config else None
+ kwargs_combination_params = combination_strategy_config.get("params", {})
+ combinations = combination_type(**kwargs_combination_params).generate() if combination_type else []
+
+ # Default model configuration
+ default_model_config = config.get("default_model", {})
+ default_model_type = get_import(default_model_config.get("model_type"))
+ default_model_params = default_model_config.get("params", {})
+
+ multi_modal = config.get("multi_modal")
+
+
+
+ #ner move it to related dataset
+
+ # ner_data_file = config.get("ner_data_file")
+ # ner_threshold = config.get("ner_threshold")
+
+
+ combination_type = None
+ kwargs_combination_params=None
+ if config.get("combination_strategy"):
+ combination_type = get_import(config.get("combination_strategy").get("type"))
+ kwargs_combination_params = config.get("combination_strategy").get("params")
+ combinations = []
+ if combination_type is not None:
+ combinations = combination_type(**kwargs_combination_params).generate()
+
+
+ pipeline = None
+ if type == "general":
+ pipeline = Pipeline(
+ library=library,
+ use_mlflow=use_mlflow,
+ experiment_name=experiment_name,
+ experiment_description=experiment_description,
+ experiment_tags=experiment_tags,
+ artifact_location=artifact_location,
+ tracking_uri=tracking_uri,
+ dataset_type=dataset_type,
+ dataset_splitter_type=dataset_splitter_type,
+ columns=columns,
+ column_embedding_configs=column_embedding_configs,
+ vector_db_persist_directory=vector_db_persist_directory,
+ vector_db_collection_name=vector_db_collection_name,
+ embedding_pooling_strategy_type=embedding_pooling_strategy,
+ ner_data_file=ner_data_file,
+ ner_threshold=ner_threshold,
+ combinations=combinations,
+ default_model=default_model,
+ multi_modal= multi_modal)
+ elif type== "ner_search":
+ pipeline = NerParameterSearch(
+ library=library,
+ experiment_name=experiment_name,
+ experiment_description=experiment_description,
+ experiment_tags=experiment_tags,
+ tracking_uri=tracking_uri,
+ dataset_type=dataset_type,
+ umls_code_types = None,
+ text_types = None,
+ columns=['tui', 'cui', 'entities'],
+ ner_data_file=ner_data_file,
+ multi_modal= multi_modal
+ )
+
+
+ return {
+ "name": experiment_name,
+ "library": library,
+ "pipeline": pipeline}
+
+ def build(self):
+ for config in self.experiments_config['experiments']:
+ item = self.__create_pipeline(config)
+ self.items.append(item)
+ return self
+
+ def run(self):
+ for item in self.items:
+ print(f"{item['name']} is running")
+ pipeline = item['pipeline']
+ pipeline.build()
+ result = pipeline.run()
+ self.pipeline_resuts[item['name']] = result
+ return self
+
+ def results(self):
+ return self.pipeline_resuts
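The new multi_pipeline_v2.py resolves dataset, splitter, pooling-strategy, and model classes from dotted-path strings in the config via `get_import`. A minimal usage sketch (the stdlib target is purely illustrative; real configs would name `ddi_fw` classes):

    # Assumes the module path shown in RECORD below: ddi_fw/pipeline/multi_pipeline_v2.py
    from ddi_fw.pipeline.multi_pipeline_v2 import get_import

    dataset_cls = get_import("collections.OrderedDict")  # e.g. a dotted path from a JSON config
    print(dataset_cls(a=1))                              # an OrderedDict instance

Note that, as diffed, `__create_pipeline` still references several names that are never assigned in the new module (`use_mlflow`, `artifact_location`, `tracking_uri`, `ner_data_file`, `ner_threshold`, `default_model`), so both the "general" and "ner_search" branches would raise NameError if reached.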
@@ -1,139 +1,134 @@
  from collections import defaultdict
+ from typing import Any, Dict, List, Optional, Type
+ from itertools import product
  import numpy as np
+ import mlflow
+ from pydantic import BaseModel, Field, model_validator, root_validator, validator
  from ddi_fw.datasets.core import BaseDataset
+ from ddi_fw.datasets.dataset_splitter import DatasetSplitter
  from ddi_fw.vectorization.idf_helper import IDF
- from typing import Any, Dict, List, Optional
- from itertools import product
-
+ from ddi_fw.ner.ner import CTakesNER
  from ddi_fw.ml.ml_helper import MultiModalRunner
  from ddi_fw.utils.enums import DrugBankTextDataTypes, UMLSCodeTypes
- import mlflow
- from ddi_fw.ner.ner import CTakesNER


- def stack(df_column):
- return np.stack(df_column.values)
-
-
- class NerParameterSearch:
- def __init__(self,
- library,
- multi_modal,
- experiment_name,
- experiment_description,
- experiment_tags,
- tracking_uri,
- dataset_type: BaseDataset,
- ner_data_file,
- columns: list,
- umls_code_types: List[UMLSCodeTypes]|None,
- text_types:List[DrugBankTextDataTypes]|None,
- min_threshold_dict: Dict[str, float] = defaultdict(float),
- max_threshold_dict: Dict[str, float] = defaultdict(float),
- increase_step=0.5):
- self.library = library
- self.multi_modal = multi_modal
- self.experiment_name = experiment_name
- self.experiment_description = experiment_description
- self.experiment_tags = experiment_tags
- self.tracking_uri = tracking_uri
-
- self.dataset_type = dataset_type
- self.ner_data_file = ner_data_file
- self.columns = columns
- self.umls_code_types = umls_code_types
- self.text_types = text_types
-
- self.min_threshold_dict = min_threshold_dict
- self.max_threshold_dict = max_threshold_dict
- self.increase_step = increase_step
+ class NerParameterSearch(BaseModel):
+ library: str
+ default_model: Optional[Any] = None
+ multi_modal: Optional[Any] = None
+ experiment_name: str
+ experiment_description: Optional[str] = None
+ experiment_tags: Optional[Dict[str, Any]] = None
+ tracking_uri: str
+ dataset_type: Type[BaseDataset]
+ dataset_splitter_type: Type[DatasetSplitter] = DatasetSplitter
+ ner_data_file: Optional[str] = None
+ columns: List[str] = Field(default_factory=list)
+ umls_code_types: Optional[List[UMLSCodeTypes]] = None
+ text_types: Optional[List[DrugBankTextDataTypes]] = None
+ min_threshold_dict: Dict[str, float] = Field(default_factory=lambda: defaultdict(float))
+ max_threshold_dict: Dict[str, float] = Field(default_factory=lambda: defaultdict(float))
+ increase_step: float = 0.5
+
+ # Internal fields (not part of the input)
+ datasets: Dict[str, Any] = Field(default_factory=dict, exclude=True)
+ items: List[Any] = Field(default_factory=list, exclude=True)
+ ner_df: Optional[Any] = Field(default=None, exclude=True)
+ train_idx_arr: Optional[List[np.ndarray]] = Field(default=None, exclude=True)
+ val_idx_arr: Optional[List[np.ndarray]] = Field(default=None, exclude=True)
+ y_test_label: Optional[np.ndarray] = Field(default=None, exclude=True)
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ # @root_validator(pre=True)
+ @model_validator(mode="before")
+ def validate_columns_and_thresholds(cls, values):
+ """Validate and initialize columns and thresholds."""
+ umls_code_types = values.get("umls_code_types")
+ text_types = values.get("text_types")
+ columns = values.get("columns", [])
+
+ if umls_code_types and text_types:
+ _umls_codes = [t.value[0] for t in umls_code_types]
+ _text_types = [t.value[0] for t in text_types]
+ _columns = [f"{item[0]}_{item[1]}" for item in product(_umls_codes, _text_types)]
+ columns.extend(_columns)
+
+ values["columns"] = columns
+ return values

  def build(self):
+ """Build the datasets and items for the parameter search."""
  if not isinstance(self.dataset_type, type):
  raise TypeError("self.dataset_type must be a class, not an instance")
- self.datasets = {}
- self.items = []
- # columns = ['tui', 'cui', 'entities']
- if self.umls_code_types is not None and self.text_types is not None:
- # add checking statements
- _umls_codes = [t.value[0] for t in self.umls_code_types]
- _text_types = [t.value[0] for t in self.text_types]
- _columns = [f'{item[0]}_{item[1]}' for item in product(
- _umls_codes, _text_types)]
- self.columns.extend(_columns)
- print(f'Columns: {self.columns}')
- self.ner_df = CTakesNER(df = None).load(
- filename=self.ner_data_file) if self.ner_data_file else None

+ # Load NER data
+ if self.ner_data_file:
+ self.ner_df = CTakesNER(df=None).load(filename=self.ner_data_file)
+
+ # Initialize thresholds if not provided
  if not self.min_threshold_dict or not self.max_threshold_dict:
- idf2 = IDF(self.ner_df, self.columns)
- idf2.calculate()
- # df = pd.DataFrame.from_dict(idf2.idf_scores)
- df = idf2.to_dataframe()
- import math
- self.min_threshold_dict = {key: math.floor(
- df.describe()[key]['min']) for key in df.describe().keys()}
- self.max_threshold_dict = {key: math.ceil(
- df.describe()[key]['max']) for key in df.describe().keys()}
-
- train_idx_arr, val_idx_arr = None, None
+ idf = IDF(self.ner_df, self.columns)
+ idf.calculate()
+ df = idf.to_dataframe()
+ self.min_threshold_dict = {key: np.floor(df.describe()[key]["min"]) for key in df.describe().keys()}
+ self.max_threshold_dict = {key: np.ceil(df.describe()[key]["max"]) for key in df.describe().keys()}
+
+ # Generate datasets and items
  for column in self.columns:
  min_threshold = self.min_threshold_dict[column]
  max_threshold = self.max_threshold_dict[column]
- kwargs = {}
- kwargs['threshold_method'] = 'idf'
- kwargs['tui_threshold'] = 0
- kwargs['cui_threshold'] = 0
- kwargs['entities_threshold'] = 0
+ kwargs = {
+ "threshold_method": "idf",
+ "tui_threshold": 0,
+ "cui_threshold": 0,
+ "entities_threshold": 0,
+ }

  for threshold in np.arange(min_threshold, max_threshold, self.increase_step):
- print(threshold)
- if column.startswith('tui'):
- kwargs['tui_threshold'] = threshold
- if column.startswith('cui'):
- kwargs['cui_threshold'] = threshold
- if column.startswith('entities'):
- kwargs['entities_threshold'] = threshold
+ if column.startswith("tui"):
+ kwargs["tui_threshold"] = threshold
+ if column.startswith("cui"):
+ kwargs["cui_threshold"] = threshold
+ if column.startswith("entities"):
+ kwargs["entities_threshold"] = threshold
+
  dataset = self.dataset_type(
- # chemical_property_columns=[],
- # embedding_columns=[],
- # ner_columns=[column],
  columns=[column],
  ner_df=self.ner_df,
- embedding_size=None,
- embedding_dict=None,
- embeddings_pooling_strategy=None,
- **kwargs)
-
- # train_idx_arr, val_idx_arr bir kez hesaplanması yeterli aslında
+ dataset_splitter_type=self.dataset_splitter_type,
+ **kwargs,
+ )
  dataset.load()
  group_items = dataset.produce_inputs()
+
  for item in group_items:
- # item[0] = f'threshold_{threshold}_{item[0]}'
- item[0] = f'threshold_{item[0]}_{threshold}'
- self.datasets[item[0]] = dataset.ddis_df
+ item[0] = f"threshold_{item[0]}_{threshold}"
+ self.datasets[item[0]] = dataset

  self.items.extend(group_items)
+
  self.y_test_label = self.items[0][4]
  self.train_idx_arr = dataset.train_idx_arr
  self.val_idx_arr = dataset.val_idx_arr

- def run(self, model_func, batch_size=128, epochs=100):
+ def run(self):
+ """Run the parameter search."""
  mlflow.set_tracking_uri(self.tracking_uri)

- if mlflow.get_experiment_by_name(self.experiment_name) == None:
+ if mlflow.get_experiment_by_name(self.experiment_name) is None:
  mlflow.create_experiment(self.experiment_name)
+ if self.experiment_tags:
  mlflow.set_experiment_tags(self.experiment_tags)
  mlflow.set_experiment(self.experiment_name)

- y_test_label = self.items[0][4]
  multi_modal_runner = MultiModalRunner(
- library=self.library, multi_modal=self.multi_modal)
- # multi_modal_runner = MultiModalRunner(
- # library=self.library, model_func=model_func, batch_size=batch_size, epochs=epochs)
- multi_modal_runner.set_data(
- self.items, self.train_idx_arr, self.val_idx_arr, y_test_label)
+ library=self.library,
+ multi_modal=self.multi_modal,
+ default_model=self.default_model,
+ use_mlflow=True,
+ )
+ multi_modal_runner.set_data(self.items, self.train_idx_arr, self.val_idx_arr, self.y_test_label)
  result = multi_modal_runner.predict()
-
-
- return result
+ return result
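The rewrite above turns NerParameterSearch into a pydantic model; its `mode="before"` validator sees the raw input and derives the NER column list from the UMLS code and text types before field validation runs. A condensed sketch of the same pattern (field names simplified, plain strings standing in for the enum types):

    from itertools import product
    from typing import List, Optional
    from pydantic import BaseModel, Field, model_validator

    class Search(BaseModel):
        columns: List[str] = Field(default_factory=list)
        umls_code_types: Optional[List[str]] = None
        text_types: Optional[List[str]] = None

        @model_validator(mode="before")
        @classmethod
        def derive_columns(cls, values):
            # Runs on the raw input dict, before pydantic validates the fields.
            if values.get("umls_code_types") and values.get("text_types"):
                cols = list(values.get("columns", []))
                cols += [f"{u}_{t}" for u, t in product(values["umls_code_types"], values["text_types"])]
                values["columns"] = cols
            return values

    s = Search(umls_code_types=["tui"], text_types=["description"])
    print(s.columns)  # ['tui_description']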
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: ddi_fw
- Version: 0.0.214
+ Version: 0.0.215
  Summary: Do not use :)
  Author-email: Kıvanç Bayraktar <bayraktarkivanc@gmail.com>
  Maintainer-email: Kıvanç Bayraktar <bayraktarkivanc@gmail.com>
@@ -1,9 +1,9 @@
  ddi_fw/datasets/__init__.py,sha256=_I3iDHARwzmg7_EL5XKtB_TgG1yAkLSOVTujLL9Wz9Q,280
  ddi_fw/datasets/core.py,sha256=ciIOma5--rArgpCjtALpgBBZaKk2oI6rJeGCs05h4iU,16685
  ddi_fw/datasets/dataset_splitter.py,sha256=8H8uZTAf8N9LUZeSeHOMawtJFJhnDgUUqFcnl7dquBQ,1672
- ddi_fw/datasets/db_utils.py,sha256=OTsa3d-Iic7z3HmzSQK9UigedRbHDxYChJk0s4GfLnw,6191
+ ddi_fw/datasets/db_utils.py,sha256=xRj28U_uXTRPHcz3yIICczFUHXUPiAOZtAj5BM6kH44,6465
  ddi_fw/datasets/setup_._py,sha256=khYVJuW5PlOY_i_A16F3UbSZ6s6o_ljw33Byw3C-A8E,1047
- ddi_fw/datasets/ddi_mdl/base.py,sha256=afk5ToGSnHQ-N7dnhiJYIiFTBM08vwEJgb6mwiacY1w,10409
+ ddi_fw/datasets/ddi_mdl/base.py,sha256=Vvyzxd2BnFK9Bn2mn-3aS5ZczlPElQ0-TKMAqgkyJiI,10483
  ddi_fw/datasets/ddi_mdl/debug.log,sha256=eWz05j8RFqZuHFDTCF7Rck5w4rvtTanFN21iZsgxO7Y,115
  ddi_fw/datasets/ddi_mdl/readme.md,sha256=WC6lpmsEKvIISnZqENY7TWtzCQr98HPpE3oRsBl8pIw,625
  ddi_fw/datasets/ddi_mdl/data/event.db,sha256=cmlSsf9MYjRzqR-mw3cUDnTnfT6FkpOG2yCl2mMwwew,30580736
@@ -31,7 +31,7 @@ ddi_fw/datasets/ddi_mdl/indexes_old/validation_fold_1.txt,sha256=KuRIXhQ4uUrYK0d
  ddi_fw/datasets/ddi_mdl/indexes_old/validation_fold_2.txt,sha256=bkZuigOiSJleKx4hJ7L0aF0ltMgaI7wfrODWiJ1ppkY,69374
  ddi_fw/datasets/ddi_mdl/indexes_old/validation_fold_3.txt,sha256=HNBs4vHhyPepnZyeeqRWfrr25heB9CKIo9-5Xh02Pyk,69393
  ddi_fw/datasets/ddi_mdl/indexes_old/validation_fold_4.txt,sha256=-pphbOlEyKD2xedpzkhWvPExDm43FKZ2RSCiie-ZHHo,69356
- ddi_fw/datasets/ddi_mdl_text/base.py,sha256=JCQvqJPAaOks0VTNRf_birph70CrHoIow5Z3rfX_cdw,6721
+ ddi_fw/datasets/ddi_mdl_text/base.py,sha256=UHhrb8ts3IOpl7MdxopBjiKK5xGfzMysc-OdGyDKpCM,6713
  ddi_fw/datasets/ddi_mdl_text/data/event.db,sha256=cmlSsf9MYjRzqR-mw3cUDnTnfT6FkpOG2yCl2mMwwew,30580736
  ddi_fw/datasets/ddi_mdl_text/indexes/test_indexes.txt,sha256=XVlDqYATckrQwNSXqMSKVBqyoN_Hg8SK6CL-XMdLADY,102176
  ddi_fw/datasets/ddi_mdl_text/indexes/train_fold_0.txt,sha256=YDQwMNEpEjkOJiUtX_BoymdezE34W3mj1Q0oAYG3mIs,325015
@@ -46,7 +46,7 @@ ddi_fw/datasets/ddi_mdl_text/indexes/validation_fold_2.txt,sha256=fFJbN0DbKH4mve
  ddi_fw/datasets/ddi_mdl_text/indexes/validation_fold_3.txt,sha256=NhiLF_5INQCpjOlE-RIxDKy7rYwksLdx60L6HCmDKoY,81247
  ddi_fw/datasets/ddi_mdl_text/indexes/validation_fold_4.txt,sha256=bPvMCJVy7jtcaYbR-5bmdB6s7gT8NSfK2wDC7iJ0O10,81308
  ddi_fw/datasets/mdf_sa_ddi/__init__.py,sha256=UEFBM92y2aJjlMJw4Jx405tOAwJ88r_nHAVgAszSjuo,68
- ddi_fw/datasets/mdf_sa_ddi/base.py,sha256=kYNmtg-s0V7mP-wjLMaAstNCG3vckMPQSE651RA_LAE,6502
+ ddi_fw/datasets/mdf_sa_ddi/base.py,sha256=ILdvu7pBMazt-FxRWzIaqO2PmbkyooEOT3U9vSoV3PY,6398
  ddi_fw/datasets/mdf_sa_ddi/df_extraction_cleanxiaoyu50.csv,sha256=EOOLF_0vVVzShoofcGYlOzpztlM1m9jJdftepHicix4,25787699
  ddi_fw/datasets/mdf_sa_ddi/drug_information_del_noDDIxiaoyu50.csv,sha256=lpuMz5KxPsG6MKNuIIUmT5cZquWHQiIao8tXlmOHzq8,381321
  ddi_fw/datasets/mdf_sa_ddi/mdf-sa-ddi.zip,sha256=DfN8mczGvWba2y45cPqtWtXjUDXy49VOtRfpcb0tn8c,4382827
@@ -83,8 +83,9 @@ ddi_fw/ner/mmlrestclient.py,sha256=NZta7m2Qm6I_qtVguMZhqtAUjVBmmXn0-TMnsNp0jpg,6
  ddi_fw/ner/ner.py,sha256=FHyyX53Xwpdw8Hec261dyN88yD7Z9LmJua2mIrQLguI,17967
  ddi_fw/pipeline/__init__.py,sha256=tKDM_rW4vPjlYTeOkNgi9PujDzb4e9O3LK1w5wqnebw,212
  ddi_fw/pipeline/multi_modal_combination_strategy.py,sha256=JSyuP71b1I1yuk0s2ecCJZTtCED85jBtkpwTUxibJvI,1706
- ddi_fw/pipeline/multi_pipeline.py,sha256=SZFJ9QSPD_3mcG9NHZOtMqKyNvyWrodsdsLryMyDdUw,8686
- ddi_fw/pipeline/ner_pipeline.py,sha256=Bp6BA6nozfWFaMHH6jKlzesnCGO6qiMkzdGy_ed6nh0,5947
+ ddi_fw/pipeline/multi_pipeline.py,sha256=AbErwu05-3YIPnCcXRsj-jxPJG8HG2H7cMZlGjzaYa8,9037
+ ddi_fw/pipeline/multi_pipeline_v2.py,sha256=7IGtaGFhgJqW29a6nDheUrVtn_7_xvWFdD6GC--sehM,10003
+ ddi_fw/pipeline/ner_pipeline.py,sha256=yp-Met2794EKcgr8_3gqt03l4v2efOdaZuAcIXTubvQ,5780
  ddi_fw/pipeline/pipeline.py,sha256=YhUBVLC29ZD2tmVd0e8X1FVBLhSKECZL2OP57oEW6HE,9171
  ddi_fw/utils/__init__.py,sha256=WNxkQXk-694roG50D355TGLXstfdWVb_tUyr-PM-8rg,537
  ddi_fw/utils/categorical_data_encoding_checker.py,sha256=T1X70Rh4atucAuqyUZmz-iFULllY9dY0NRyV9-jTjJ0,3438
@@ -99,7 +100,7 @@ ddi_fw/utils/zip_helper.py,sha256=YRZA4tKZVBJwGQM0_WK6L-y5MoqkKoC-nXuuHK6CU9I,55
  ddi_fw/vectorization/__init__.py,sha256=LcJOpLVoLvHPDw9phGFlUQGeNcST_zKV-Oi1Pm5h_nE,110
  ddi_fw/vectorization/feature_vector_generation.py,sha256=EBf-XAiwQwr68az91erEYNegfeqssBR29kVgrliIyac,4765
  ddi_fw/vectorization/idf_helper.py,sha256=_Gd1dtDSLaw8o-o0JugzSKMt9FpeXewTh4wGEaUd4VQ,2571
- ddi_fw-0.0.214.dist-info/METADATA,sha256=IEDJdH40Nw4B0aJXnUwuxeNRdXMX5rw1RBsX93Zbj1A,2631
- ddi_fw-0.0.214.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
- ddi_fw-0.0.214.dist-info/top_level.txt,sha256=PMwHICFZTZtcpzQNPV4UQnfNXYIeLR_Ste-Wfc1h810,7
- ddi_fw-0.0.214.dist-info/RECORD,,
+ ddi_fw-0.0.215.dist-info/METADATA,sha256=A_IcjN3qyvwbfFH-51g6NGR-8DUmwWEHWPu4sLfGuzc,2631
+ ddi_fw-0.0.215.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+ ddi_fw-0.0.215.dist-info/top_level.txt,sha256=PMwHICFZTZtcpzQNPV4UQnfNXYIeLR_Ste-Wfc1h810,7
+ ddi_fw-0.0.215.dist-info/RECORD,,