ddi-fw 0.0.197__py3-none-any.whl → 0.0.198__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddi_fw/datasets/core.py +55 -41
- ddi_fw/datasets/ddi_mdl/base.py +11 -9
- {ddi_fw-0.0.197.dist-info → ddi_fw-0.0.198.dist-info}/METADATA +1 -1
- {ddi_fw-0.0.197.dist-info → ddi_fw-0.0.198.dist-info}/RECORD +6 -6
- {ddi_fw-0.0.197.dist-info → ddi_fw-0.0.198.dist-info}/WHEEL +0 -0
- {ddi_fw-0.0.197.dist-info → ddi_fw-0.0.198.dist-info}/top_level.txt +0 -0
ddi_fw/datasets/core.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1
|
+
import abc
|
1
2
|
from collections import defaultdict
|
2
3
|
import glob
|
3
4
|
import logging
|
@@ -57,7 +58,7 @@ def generate_sim_matrices_new(df, generated_vectors, columns, key_column="id"):
|
|
57
58
|
return similarity_matrices
|
58
59
|
|
59
60
|
|
60
|
-
class BaseDataset(BaseModel):
|
61
|
+
class BaseDataset(BaseModel, abc.ABC):
|
61
62
|
dataset_name: str
|
62
63
|
index_path: Optional[str] = None
|
63
64
|
dataset_splitter_type: Type[DatasetSplitter]
|
@@ -125,19 +126,26 @@ class BaseDataset(BaseModel):
|
|
125
126
|
def set_dataframe(self, dataframe: pd.DataFrame):
|
126
127
|
self.dataframe = dataframe
|
127
128
|
|
128
|
-
|
129
|
+
@abc.abstractmethod
|
129
130
|
def prep(self):
|
130
|
-
|
131
|
+
"""Prepare the dataset. This method should be overridden in subclasses."""
|
132
|
+
|
131
133
|
|
134
|
+
def handle_mixins(self):
|
135
|
+
"""Handle mixin-specific logic."""
|
136
|
+
if isinstance(self, TextDatasetMixin):
|
137
|
+
self.process_text()
|
138
|
+
# if isinstance(self, ImageDatasetMixin):
|
139
|
+
# self.process_image_data()
|
140
|
+
# Add other mixin-specific logic here
|
141
|
+
|
132
142
|
def load(self):
|
133
143
|
"""
|
134
144
|
Load the dataset. If X_train, y_train, X_test, and y_test are already provided,
|
135
145
|
skip deriving them. Otherwise, derive them from the dataframe and indices.
|
136
146
|
"""
|
137
|
-
self.prep()
|
138
|
-
|
139
|
-
if isinstance(self, TextDatasetMixin):
|
140
|
-
self.process_text()
|
147
|
+
self.prep() # Prepare the dataset
|
148
|
+
self.handle_mixins() # Centralized mixin handling
|
141
149
|
|
142
150
|
if self.X_train is not None or self.y_train is not None or self.X_test is not None or self.y_test is not None:
|
143
151
|
# Data is already provided, no need to calculate
|
@@ -158,9 +166,11 @@ class BaseDataset(BaseModel):
|
|
158
166
|
self.index_path)
|
159
167
|
except FileNotFoundError as e:
|
160
168
|
raise FileNotFoundError(f"Index files not found: {e.filename}")
|
161
|
-
|
162
|
-
train = self.dataframe[self.dataframe.index.isin(train_idx_all)]
|
163
|
-
test = self.dataframe[self.dataframe.index.isin(test_idx_all)]
|
169
|
+
|
170
|
+
# train = self.dataframe[self.dataframe.index.isin(train_idx_all)]
|
171
|
+
# test = self.dataframe[self.dataframe.index.isin(test_idx_all)]
|
172
|
+
train = self.dataframe.loc[self.dataframe.index.isin(train_idx_all), self.columns]
|
173
|
+
test = self.dataframe.loc[self.dataframe.index.isin(test_idx_all), self.columns]
|
164
174
|
X_train = train.drop(self.class_column, axis=1)
|
165
175
|
X_train = train.drop(self.class_column, axis=1)
|
166
176
|
y_train = train[self.class_column]
|
@@ -259,13 +269,18 @@ class BaseDataset(BaseModel):
|
|
259
269
|
|
260
270
|
|
261
271
|
class TextDatasetMixin(BaseModel):
|
262
|
-
embedding_size: Optional[int] = None
|
263
272
|
embedding_dict: Dict[str, Any] | None = Field(
|
264
273
|
default_factory=dict, description="Dictionary for embeddings")
|
265
274
|
pooling_strategy: PoolingStrategy | None = None
|
266
275
|
column_embedding_configs: Optional[Dict] = None
|
267
276
|
vector_db_persist_directory: Optional[str] = None
|
268
277
|
vector_db_collection_name: Optional[str] = None
|
278
|
+
_embedding_size: int
|
279
|
+
|
280
|
+
@computed_field
|
281
|
+
@property
|
282
|
+
def embedding_size(self) -> int:
|
283
|
+
return self._embedding_size
|
269
284
|
|
270
285
|
class Config:
|
271
286
|
arbitrary_types_allowed = True
|
@@ -317,44 +332,43 @@ class TextDatasetMixin(BaseModel):
|
|
317
332
|
else:
|
318
333
|
raise ValueError(
|
319
334
|
"Persistent directory for the vector DB is not specified.")
|
335
|
+
|
336
|
+
def __initialize_embedding_dict(self):
|
337
|
+
embedding_dict = defaultdict(lambda: defaultdict(list))
|
338
|
+
if self.column_embedding_configs:
|
339
|
+
for item in self.column_embedding_configs:
|
340
|
+
col = item["column"]
|
341
|
+
col_db_dir = item["vector_db_persist_directory"]
|
342
|
+
col_db_collection = item["vector_db_collection_name"]
|
343
|
+
self.__create_or_update_embeddings__(embedding_dict, col_db_dir, col_db_collection, col)
|
344
|
+
elif self.vector_db_persist_directory:
|
345
|
+
self.__create_or_update_embeddings__(embedding_dict, self.vector_db_persist_directory, self.vector_db_collection_name)
|
346
|
+
else:
|
347
|
+
logging.warning("There is no configuration of Embeddings")
|
348
|
+
raise ValueError(
|
349
|
+
"There is no configuration of Embeddings. Please provide a vector database directory and collection name.")
|
350
|
+
return embedding_dict
|
320
351
|
|
321
|
-
def
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
352
|
+
def __calculate_embedding_size(self):
|
353
|
+
if self.embedding_dict is None:
|
354
|
+
raise ValueError("Embedding dictionary is not initialized, embedding size cannot be calculated.")
|
355
|
+
|
356
|
+
key, value = next(iter(self.embedding_dict.items()))
|
357
|
+
self._embedding_size = value[next(iter(value))][0].shape[0]
|
326
358
|
|
359
|
+
def process_text(self):
|
360
|
+
logging.info("Processing text data...")
|
361
|
+
|
327
362
|
# 'enzyme','target','pathway','smile','all_text','indication', 'description','mechanism_of_action','pharmacodynamics', 'tui', 'cui', 'entities'
|
328
363
|
# kwargs = {"columns": self.columns}
|
329
364
|
# if self.ner_threshold:
|
330
365
|
# for k, v in self.ner_threshold.items():
|
331
366
|
# kwargs[k] = v
|
332
|
-
if self.embedding_dict
|
333
|
-
embedding_dict =
|
334
|
-
# TODO find more effective solution
|
335
|
-
|
336
|
-
if self.column_embedding_configs:
|
337
|
-
for item in self.column_embedding_configs:
|
338
|
-
col = item["column"]
|
339
|
-
col_db_dir = item["vector_db_persist_directory"]
|
340
|
-
col_db_collection = item["vector_db_collection_name"]
|
341
|
-
self.__create_or_update_embeddings__(
|
342
|
-
embedding_dict, col_db_dir, col_db_collection, col)
|
343
|
-
|
344
|
-
elif self.vector_db_persist_directory:
|
345
|
-
self.__create_or_update_embeddings__(
|
346
|
-
embedding_dict, self.vector_db_persist_directory, self.vector_db_collection_name)
|
367
|
+
if self.embedding_dict is None:
|
368
|
+
self.embedding_dict = self.__initialize_embedding_dict()
|
347
369
|
|
348
|
-
|
349
|
-
|
350
|
-
f"There is no configuration of Embeddings")
|
351
|
-
self.embedding_dict = embedding_dict
|
352
|
-
|
353
|
-
# else:
|
354
|
-
# embedding_dict = self.embedding_dict
|
355
|
-
# TODO make generic
|
356
|
-
# embedding_size = list(embedding_dict['all_text'].values())[
|
357
|
-
# 0][0].shape
|
370
|
+
self.__calculate_embedding_size()
|
371
|
+
|
358
372
|
|
359
373
|
|
360
374
|
# class ImageDatasetMixin(BaseModel):
|
ddi_fw/datasets/ddi_mdl/base.py
CHANGED
@@ -91,7 +91,7 @@ class DDIMDLDataset(BaseDataset,TextDatasetMixin):
|
|
91
91
|
self.__similarity_related_columns__.extend(self.ner_columns)
|
92
92
|
# TODO with resource
|
93
93
|
self._conn = create_connection(_db_path.absolute().as_posix())
|
94
|
-
self.load_drugs_and_events()
|
94
|
+
# self.load_drugs_and_events()
|
95
95
|
logger.info(f'{self.dataset_name} is initialized')
|
96
96
|
|
97
97
|
def load_drugs_and_events(self):
|
@@ -131,6 +131,7 @@ class DDIMDLDataset(BaseDataset,TextDatasetMixin):
|
|
131
131
|
return pd.DataFrame(columns=headers, data=rows)
|
132
132
|
|
133
133
|
def prep(self):
|
134
|
+
self.load_drugs_and_events()
|
134
135
|
if self.drugs_df is None or self.ddis_df is None:
|
135
136
|
raise Exception("There is no data")
|
136
137
|
|
@@ -220,14 +221,15 @@ class DDIMDLDataset(BaseDataset,TextDatasetMixin):
|
|
220
221
|
self.columns.append(key)
|
221
222
|
print(self.ddis_df[key].head())
|
222
223
|
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
224
|
+
if self.embedding_dict is not None:
|
225
|
+
for embedding_column in self.embedding_columns:
|
226
|
+
print(f"concat {embedding_column} embeddings")
|
227
|
+
embeddings_after_pooling = {k: self.pooling_strategy.apply(
|
228
|
+
v) for k, v in self.embedding_dict[embedding_column].items()}
|
229
|
+
# column_embeddings_dict = embedding_values[embedding_column]
|
230
|
+
self.ddis_df[embedding_column+'_embedding'] = self.ddis_df.apply(
|
231
|
+
x_fnc, args=(embeddings_after_pooling,), axis=1)
|
232
|
+
self.columns.append(embedding_column+'_embedding')
|
231
233
|
|
232
234
|
dataframe = self.ddis_df.copy()
|
233
235
|
if not isinstance(classes, (list, pd.Series, np.ndarray)):
|
@@ -1,9 +1,9 @@
|
|
1
1
|
ddi_fw/datasets/__init__.py,sha256=_I3iDHARwzmg7_EL5XKtB_TgG1yAkLSOVTujLL9Wz9Q,280
|
2
|
-
ddi_fw/datasets/core.py,sha256=
|
2
|
+
ddi_fw/datasets/core.py,sha256=yfnJwyF9oV2RUErFSAKSyxQQeL1tmLiq7SfADhn1Cgk,16379
|
3
3
|
ddi_fw/datasets/dataset_splitter.py,sha256=8H8uZTAf8N9LUZeSeHOMawtJFJhnDgUUqFcnl7dquBQ,1672
|
4
4
|
ddi_fw/datasets/db_utils.py,sha256=OTsa3d-Iic7z3HmzSQK9UigedRbHDxYChJk0s4GfLnw,6191
|
5
5
|
ddi_fw/datasets/setup_._py,sha256=khYVJuW5PlOY_i_A16F3UbSZ6s6o_ljw33Byw3C-A8E,1047
|
6
|
-
ddi_fw/datasets/ddi_mdl/base.py,sha256=
|
6
|
+
ddi_fw/datasets/ddi_mdl/base.py,sha256=8WFc0iLT5PF6IOUStqKVIKR74D8WBuwXm_uMiV4OFsk,10324
|
7
7
|
ddi_fw/datasets/ddi_mdl/debug.log,sha256=eWz05j8RFqZuHFDTCF7Rck5w4rvtTanFN21iZsgxO7Y,115
|
8
8
|
ddi_fw/datasets/ddi_mdl/readme.md,sha256=WC6lpmsEKvIISnZqENY7TWtzCQr98HPpE3oRsBl8pIw,625
|
9
9
|
ddi_fw/datasets/ddi_mdl/data/event.db,sha256=cmlSsf9MYjRzqR-mw3cUDnTnfT6FkpOG2yCl2mMwwew,30580736
|
@@ -99,7 +99,7 @@ ddi_fw/utils/zip_helper.py,sha256=YRZA4tKZVBJwGQM0_WK6L-y5MoqkKoC-nXuuHK6CU9I,55
|
|
99
99
|
ddi_fw/vectorization/__init__.py,sha256=LcJOpLVoLvHPDw9phGFlUQGeNcST_zKV-Oi1Pm5h_nE,110
|
100
100
|
ddi_fw/vectorization/feature_vector_generation.py,sha256=EBf-XAiwQwr68az91erEYNegfeqssBR29kVgrliIyac,4765
|
101
101
|
ddi_fw/vectorization/idf_helper.py,sha256=_Gd1dtDSLaw8o-o0JugzSKMt9FpeXewTh4wGEaUd4VQ,2571
|
102
|
-
ddi_fw-0.0.
|
103
|
-
ddi_fw-0.0.
|
104
|
-
ddi_fw-0.0.
|
105
|
-
ddi_fw-0.0.
|
102
|
+
ddi_fw-0.0.198.dist-info/METADATA,sha256=z3otymNU3l4737h3tkMaP0UMhZdLBtzS4ELP4wIcVt8,2542
|
103
|
+
ddi_fw-0.0.198.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
104
|
+
ddi_fw-0.0.198.dist-info/top_level.txt,sha256=PMwHICFZTZtcpzQNPV4UQnfNXYIeLR_Ste-Wfc1h810,7
|
105
|
+
ddi_fw-0.0.198.dist-info/RECORD,,
|
File without changes
|
File without changes
|