ddi-fw 0.0.196__py3-none-any.whl → 0.0.198__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ddi_fw/datasets/core.py CHANGED
@@ -1,3 +1,4 @@
1
+ import abc
1
2
  from collections import defaultdict
2
3
  import glob
3
4
  import logging
@@ -57,7 +58,7 @@ def generate_sim_matrices_new(df, generated_vectors, columns, key_column="id"):
57
58
  return similarity_matrices
58
59
 
59
60
 
60
- class BaseDataset(BaseModel):
61
+ class BaseDataset(BaseModel, abc.ABC):
61
62
  dataset_name: str
62
63
  index_path: Optional[str] = None
63
64
  dataset_splitter_type: Type[DatasetSplitter]
@@ -125,19 +126,26 @@ class BaseDataset(BaseModel):
125
126
  def set_dataframe(self, dataframe: pd.DataFrame):
126
127
  self.dataframe = dataframe
127
128
 
128
- # @abstractmethod
129
+ @abc.abstractmethod
129
130
  def prep(self):
130
- pass
131
+ """Prepare the dataset. This method should be overridden in subclasses."""
132
+
131
133
 
134
+ def handle_mixins(self):
135
+ """Handle mixin-specific logic."""
136
+ if isinstance(self, TextDatasetMixin):
137
+ self.process_text()
138
+ # if isinstance(self, ImageDatasetMixin):
139
+ # self.process_image_data()
140
+ # Add other mixin-specific logic here
141
+
132
142
  def load(self):
133
143
  """
134
144
  Load the dataset. If X_train, y_train, X_test, and y_test are already provided,
135
145
  skip deriving them. Otherwise, derive them from the dataframe and indices.
136
146
  """
137
- self.prep()
138
-
139
- if isinstance(self, TextDatasetMixin):
140
- self.process_text()
147
+ self.prep() # Prepare the dataset
148
+ self.handle_mixins() # Centralized mixin handling
141
149
 
142
150
  if self.X_train is not None or self.y_train is not None or self.X_test is not None or self.y_test is not None:
143
151
  # Data is already provided, no need to calculate
@@ -158,9 +166,11 @@ class BaseDataset(BaseModel):
158
166
  self.index_path)
159
167
  except FileNotFoundError as e:
160
168
  raise FileNotFoundError(f"Index files not found: {e.filename}")
161
-
162
- train = self.dataframe[self.dataframe.index.isin(train_idx_all)]
163
- test = self.dataframe[self.dataframe.index.isin(test_idx_all)]
169
+
170
+ # train = self.dataframe[self.dataframe.index.isin(train_idx_all)]
171
+ # test = self.dataframe[self.dataframe.index.isin(test_idx_all)]
172
+ train = self.dataframe.loc[self.dataframe.index.isin(train_idx_all), self.columns]
173
+ test = self.dataframe.loc[self.dataframe.index.isin(test_idx_all), self.columns]
164
174
  X_train = train.drop(self.class_column, axis=1)
165
175
  X_train = train.drop(self.class_column, axis=1)
166
176
  y_train = train[self.class_column]
@@ -259,13 +269,21 @@ class BaseDataset(BaseModel):
259
269
 
260
270
 
261
271
  class TextDatasetMixin(BaseModel):
262
- embedding_size: Optional[int] = None
263
272
  embedding_dict: Dict[str, Any] | None = Field(
264
273
  default_factory=dict, description="Dictionary for embeddings")
265
274
  pooling_strategy: PoolingStrategy | None = None
266
275
  column_embedding_configs: Optional[Dict] = None
267
276
  vector_db_persist_directory: Optional[str] = None
268
277
  vector_db_collection_name: Optional[str] = None
278
+ _embedding_size: int
279
+
280
+ @computed_field
281
+ @property
282
+ def embedding_size(self) -> int:
283
+ return self._embedding_size
284
+
285
+ class Config:
286
+ arbitrary_types_allowed = True
269
287
 
270
288
  def __create_or_update_embeddings__(self, embedding_dict, vector_db_persist_directory, vector_db_collection_name, column=None):
271
289
  """
@@ -314,45 +332,43 @@ class TextDatasetMixin(BaseModel):
314
332
  else:
315
333
  raise ValueError(
316
334
  "Persistent directory for the vector DB is not specified.")
317
-
335
+
336
+ def __initialize_embedding_dict(self):
337
+ embedding_dict = defaultdict(lambda: defaultdict(list))
338
+ if self.column_embedding_configs:
339
+ for item in self.column_embedding_configs:
340
+ col = item["column"]
341
+ col_db_dir = item["vector_db_persist_directory"]
342
+ col_db_collection = item["vector_db_collection_name"]
343
+ self.__create_or_update_embeddings__(embedding_dict, col_db_dir, col_db_collection, col)
344
+ elif self.vector_db_persist_directory:
345
+ self.__create_or_update_embeddings__(embedding_dict, self.vector_db_persist_directory, self.vector_db_collection_name)
346
+ else:
347
+ logging.warning("There is no configuration of Embeddings")
348
+ raise ValueError(
349
+ "There is no configuration of Embeddings. Please provide a vector database directory and collection name.")
350
+ return embedding_dict
351
+
352
+ def __calculate_embedding_size(self):
353
+ if self.embedding_dict is None:
354
+ raise ValueError("Embedding dictionary is not initialized, embedding size cannot be calculated.")
355
+
356
+ key, value = next(iter(self.embedding_dict.items()))
357
+ self._embedding_size = value[next(iter(value))][0].shape[0]
358
+
318
359
  def process_text(self):
319
- # key, value = next(iter(embedding_dict.items()))
320
- # embedding_size = value[next(iter(value))][0].shape[0]
321
- # pooling_strategy = self.embedding_pooling_strategy_type(
322
- # ) if self.embedding_pooling_strategy_type else None
323
-
324
-
360
+ logging.info("Processing text data...")
361
+
325
362
  # 'enzyme','target','pathway','smile','all_text','indication', 'description','mechanism_of_action','pharmacodynamics', 'tui', 'cui', 'entities'
326
363
  # kwargs = {"columns": self.columns}
327
364
  # if self.ner_threshold:
328
365
  # for k, v in self.ner_threshold.items():
329
366
  # kwargs[k] = v
330
- if self.embedding_dict == None:
331
- embedding_dict = defaultdict(lambda: defaultdict(list))
332
- # TODO find more effective solution
333
-
334
- if self.column_embedding_configs:
335
- for item in self.column_embedding_configs:
336
- col = item["column"]
337
- col_db_dir = item["vector_db_persist_directory"]
338
- col_db_collection = item["vector_db_collection_name"]
339
- self.__create_or_update_embeddings__(
340
- embedding_dict, col_db_dir, col_db_collection, col)
341
-
342
- elif self.vector_db_persist_directory:
343
- self.__create_or_update_embeddings__(
344
- embedding_dict, self.vector_db_persist_directory, self.vector_db_collection_name)
367
+ if self.embedding_dict is None:
368
+ self.embedding_dict = self.__initialize_embedding_dict()
345
369
 
346
- else:
347
- print(
348
- f"There is no configuration of Embeddings")
349
- self.embedding_dict = embedding_dict
350
-
351
- # else:
352
- # embedding_dict = self.embedding_dict
353
- # TODO make generic
354
- # embedding_size = list(embedding_dict['all_text'].values())[
355
- # 0][0].shape
370
+ self.__calculate_embedding_size()
371
+
356
372
 
357
373
 
358
374
  # class ImageDatasetMixin(BaseModel):
@@ -91,7 +91,7 @@ class DDIMDLDataset(BaseDataset,TextDatasetMixin):
91
91
  self.__similarity_related_columns__.extend(self.ner_columns)
92
92
  # TODO with resource
93
93
  self._conn = create_connection(_db_path.absolute().as_posix())
94
- self.load_drugs_and_events()
94
+ # self.load_drugs_and_events()
95
95
  logger.info(f'{self.dataset_name} is initialized')
96
96
 
97
97
  def load_drugs_and_events(self):
@@ -131,6 +131,7 @@ class DDIMDLDataset(BaseDataset,TextDatasetMixin):
131
131
  return pd.DataFrame(columns=headers, data=rows)
132
132
 
133
133
  def prep(self):
134
+ self.load_drugs_and_events()
134
135
  if self.drugs_df is None or self.ddis_df is None:
135
136
  raise Exception("There is no data")
136
137
 
@@ -220,14 +221,15 @@ class DDIMDLDataset(BaseDataset,TextDatasetMixin):
220
221
  self.columns.append(key)
221
222
  print(self.ddis_df[key].head())
222
223
 
223
- for embedding_column in self.embedding_columns:
224
- print(f"concat {embedding_column} embeddings")
225
- embeddings_after_pooling = {k: self.embeddings_pooling_strategy.apply(
226
- v) for k, v in self.embedding_dict[embedding_column].items()}
227
- # column_embeddings_dict = embedding_values[embedding_column]
228
- self.ddis_df[embedding_column+'_embedding'] = self.ddis_df.apply(
229
- x_fnc, args=(embeddings_after_pooling,), axis=1)
230
- self.columns.append(embedding_column+'_embedding')
224
+ if self.embedding_dict is not None:
225
+ for embedding_column in self.embedding_columns:
226
+ print(f"concat {embedding_column} embeddings")
227
+ embeddings_after_pooling = {k: self.pooling_strategy.apply(
228
+ v) for k, v in self.embedding_dict[embedding_column].items()}
229
+ # column_embeddings_dict = embedding_values[embedding_column]
230
+ self.ddis_df[embedding_column+'_embedding'] = self.ddis_df.apply(
231
+ x_fnc, args=(embeddings_after_pooling,), axis=1)
232
+ self.columns.append(embedding_column+'_embedding')
231
233
 
232
234
  dataframe = self.ddis_df.copy()
233
235
  if not isinstance(classes, (list, pd.Series, np.ndarray)):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ddi_fw
3
- Version: 0.0.196
3
+ Version: 0.0.198
4
4
  Summary: Do not use :)
5
5
  Author-email: Kıvanç Bayraktar <bayraktarkivanc@gmail.com>
6
6
  Maintainer-email: Kıvanç Bayraktar <bayraktarkivanc@gmail.com>
@@ -1,9 +1,9 @@
1
1
  ddi_fw/datasets/__init__.py,sha256=_I3iDHARwzmg7_EL5XKtB_TgG1yAkLSOVTujLL9Wz9Q,280
2
- ddi_fw/datasets/core.py,sha256=MGl3qg1Lo2-QuyE3vVT1t5e8x46MKCbkSF_ZIVdcAa0,15583
2
+ ddi_fw/datasets/core.py,sha256=yfnJwyF9oV2RUErFSAKSyxQQeL1tmLiq7SfADhn1Cgk,16379
3
3
  ddi_fw/datasets/dataset_splitter.py,sha256=8H8uZTAf8N9LUZeSeHOMawtJFJhnDgUUqFcnl7dquBQ,1672
4
4
  ddi_fw/datasets/db_utils.py,sha256=OTsa3d-Iic7z3HmzSQK9UigedRbHDxYChJk0s4GfLnw,6191
5
5
  ddi_fw/datasets/setup_._py,sha256=khYVJuW5PlOY_i_A16F3UbSZ6s6o_ljw33Byw3C-A8E,1047
6
- ddi_fw/datasets/ddi_mdl/base.py,sha256=rS8lSGE-SLeoE3GuElJ-TNaRHIGhaZBeOM2UH3JUS4M,10218
6
+ ddi_fw/datasets/ddi_mdl/base.py,sha256=8WFc0iLT5PF6IOUStqKVIKR74D8WBuwXm_uMiV4OFsk,10324
7
7
  ddi_fw/datasets/ddi_mdl/debug.log,sha256=eWz05j8RFqZuHFDTCF7Rck5w4rvtTanFN21iZsgxO7Y,115
8
8
  ddi_fw/datasets/ddi_mdl/readme.md,sha256=WC6lpmsEKvIISnZqENY7TWtzCQr98HPpE3oRsBl8pIw,625
9
9
  ddi_fw/datasets/ddi_mdl/data/event.db,sha256=cmlSsf9MYjRzqR-mw3cUDnTnfT6FkpOG2yCl2mMwwew,30580736
@@ -99,7 +99,7 @@ ddi_fw/utils/zip_helper.py,sha256=YRZA4tKZVBJwGQM0_WK6L-y5MoqkKoC-nXuuHK6CU9I,55
99
99
  ddi_fw/vectorization/__init__.py,sha256=LcJOpLVoLvHPDw9phGFlUQGeNcST_zKV-Oi1Pm5h_nE,110
100
100
  ddi_fw/vectorization/feature_vector_generation.py,sha256=EBf-XAiwQwr68az91erEYNegfeqssBR29kVgrliIyac,4765
101
101
  ddi_fw/vectorization/idf_helper.py,sha256=_Gd1dtDSLaw8o-o0JugzSKMt9FpeXewTh4wGEaUd4VQ,2571
102
- ddi_fw-0.0.196.dist-info/METADATA,sha256=XjBDIDDg_a_py1xN16C5uxvUcIg7yEP_Lp3k4FIEiDs,2542
103
- ddi_fw-0.0.196.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
104
- ddi_fw-0.0.196.dist-info/top_level.txt,sha256=PMwHICFZTZtcpzQNPV4UQnfNXYIeLR_Ste-Wfc1h810,7
105
- ddi_fw-0.0.196.dist-info/RECORD,,
102
+ ddi_fw-0.0.198.dist-info/METADATA,sha256=z3otymNU3l4737h3tkMaP0UMhZdLBtzS4ELP4wIcVt8,2542
103
+ ddi_fw-0.0.198.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
104
+ ddi_fw-0.0.198.dist-info/top_level.txt,sha256=PMwHICFZTZtcpzQNPV4UQnfNXYIeLR_Ste-Wfc1h810,7
105
+ ddi_fw-0.0.198.dist-info/RECORD,,