ddi-fw 0.0.146__py3-none-any.whl → 0.0.148__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,7 +8,7 @@ from ddi_fw.utils import ZipHelper
8
8
 
9
9
  from .. import BaseDataset
10
10
  from ddi_fw.langchain.embeddings import PoolingStrategy
11
- # from ..db_utils import create_connection
11
+ from ..db_utils import create_connection
12
12
  # from ..db_utils import create_connection, select_all_drugs_as_dataframe, select_events_with_category
13
13
 
14
14
  HERE = pathlib.Path(__file__).resolve().parent
@@ -37,8 +37,32 @@ class MDFSADDIDataset(BaseDataset):
37
37
  ner_columns=[],
38
38
  **kwargs):
39
39
 
40
- super().__init__(chemical_property_columns, embedding_columns,
41
- ner_columns, **kwargs)
40
+ columns = kwargs['columns']
41
+ if columns:
42
+ chemical_property_columns = []
43
+ embedding_columns=[]
44
+ ner_columns=[]
45
+ for column in columns:
46
+ if column in list_of_chemical_property_columns:
47
+ chemical_property_columns.append(column)
48
+ elif column in list_of_embedding_columns:
49
+ embedding_columns.append(column)
50
+ elif column in list_of_ner_columns:
51
+ ner_columns.append(column)
52
+ # elif column == 'smile_2':
53
+ # continue
54
+ else:
55
+ raise Exception(f"{column} is not related this dataset")
56
+
57
+
58
+ super().__init__(embedding_size=embedding_size,
59
+ embedding_dict=embedding_dict,
60
+ embeddings_pooling_strategy=embeddings_pooling_strategy,
61
+ ner_df=ner_df,
62
+ chemical_property_columns=chemical_property_columns,
63
+ embedding_columns=embedding_columns,
64
+ ner_columns=ner_columns,
65
+ **kwargs)
42
66
 
43
67
  db_zip_path = HERE.joinpath('mdf-sa-ddi.zip')
44
68
  db_path = HERE.joinpath('mdf-sa-ddi.db')
@@ -50,7 +74,8 @@ class MDFSADDIDataset(BaseDataset):
50
74
  conn = create_connection(db_path)
51
75
  self.drugs_df = select_all_drugs_as_dataframe(conn)
52
76
  self.ddis_df = select_all_events_as_dataframe(conn)
53
- kwargs = {'index_path': str(HERE.joinpath('indexes'))}
77
+ # kwargs = {'index_path': str(HERE.joinpath('indexes'))}
78
+ kwargs['index_path'] = str(HERE.joinpath('indexes'))
54
79
 
55
80
  self.index_path = kwargs.get('index_path')
56
81
 
@@ -40,14 +40,24 @@ class TFModelWrapper(ModelWrapper):
40
40
  early_stopping = EarlyStopping(
41
41
  monitor='val_loss', patience=10, mode='auto')
42
42
  custom_callback = CustomCallback()
43
-
43
+ train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
44
+ val_dataset = tf.data.Dataset.from_tensor_slices((X_valid, y_valid))
45
+ train_dataset = train_dataset.batch(batch_size=self.batch_size)
46
+ val_dataset = val_dataset.batch(batch_size=self.batch_size)
44
47
  history = model.fit(
45
- X_train, y_train,
46
- batch_size=self.batch_size,
48
+ train_data = train_dataset,
47
49
  epochs=self.epochs,
48
- validation_data=(X_valid, y_valid),
50
+ validation_data=val_dataset,
49
51
  callbacks=[early_stopping, checkpoint, custom_callback]
50
52
  )
53
+ # ex
54
+ # history = model.fit(
55
+ # X_train, y_train,
56
+ # batch_size=self.batch_size,
57
+ # epochs=self.epochs,
58
+ # validation_data=(X_valid, y_valid),
59
+ # callbacks=[early_stopping, checkpoint, custom_callback]
60
+ # )
51
61
 
52
62
  if os.path.exists(f'{self.descriptor}_validation.weights.h5'):
53
63
  os.remove(f'{self.descriptor}_validation.weights.h5')
@@ -77,7 +87,10 @@ class TFModelWrapper(ModelWrapper):
77
87
  # https://github.com/mlflow/mlflow/blob/master/examples/tensorflow/train.py
78
88
 
79
89
  def predict(self):
80
- pred = self.best_model.predict(self.test_data)
90
+ test_dataset = tf.data.Dataset.from_tensor_slices((self.test_data, self.test_label))
91
+ test_dataset = test_dataset.batch(batch_size=1)
92
+ # pred = self.best_model.predict(self.test_data)
93
+ pred = self.best_model.predict(test_dataset)
81
94
  return pred
82
95
 
83
96
  def fit_and_evaluate(self):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ddi_fw
3
- Version: 0.0.146
3
+ Version: 0.0.148
4
4
  Summary: Do not use :)
5
5
  Author-email: Kıvanç Bayraktar <bayraktarkivanc@gmail.com>
6
6
  Maintainer-email: Kıvanç Bayraktar <bayraktarkivanc@gmail.com>
@@ -47,7 +47,7 @@ ddi_fw/datasets/ddi_mdl_text/indexes/validation_fold_2.txt,sha256=fFJbN0DbKH4mve
47
47
  ddi_fw/datasets/ddi_mdl_text/indexes/validation_fold_3.txt,sha256=NhiLF_5INQCpjOlE-RIxDKy7rYwksLdx60L6HCmDKoY,81247
48
48
  ddi_fw/datasets/ddi_mdl_text/indexes/validation_fold_4.txt,sha256=bPvMCJVy7jtcaYbR-5bmdB6s7gT8NSfK2wDC7iJ0O10,81308
49
49
  ddi_fw/datasets/mdf_sa_ddi/__init__.py,sha256=UEFBM92y2aJjlMJw4Jx405tOAwJ88r_nHAVgAszSjuo,68
50
- ddi_fw/datasets/mdf_sa_ddi/base.py,sha256=ShsDALf0lI4SDtXAmhMKOM05b2q_LStYPVXn12S9PTE,5371
50
+ ddi_fw/datasets/mdf_sa_ddi/base.py,sha256=kYNmtg-s0V7mP-wjLMaAstNCG3vckMPQSE651RA_LAE,6502
51
51
  ddi_fw/datasets/mdf_sa_ddi/df_extraction_cleanxiaoyu50.csv,sha256=EOOLF_0vVVzShoofcGYlOzpztlM1m9jJdftepHicix4,25787699
52
52
  ddi_fw/datasets/mdf_sa_ddi/drug_information_del_noDDIxiaoyu50.csv,sha256=lpuMz5KxPsG6MKNuIIUmT5cZquWHQiIao8tXlmOHzq8,381321
53
53
  ddi_fw/datasets/mdf_sa_ddi/mdf-sa-ddi.zip,sha256=DfN8mczGvWba2y45cPqtWtXjUDXy49VOtRfpcb0tn8c,4382827
@@ -78,7 +78,7 @@ ddi_fw/ml/evaluation_helper.py,sha256=o4-w5Xa3t4olLW4ymx_8L-Buhe5wfQEmT2bh4Zz544
78
78
  ddi_fw/ml/ml_helper.py,sha256=xSEa_UNpaFyrPswlQcDfZSI2x5nZLStOiKoP54SYkCM,6454
79
79
  ddi_fw/ml/model_wrapper.py,sha256=kc01_TVJuriUvNI6ABnLngnJWvmG_Y7-XJ6XMusLJ8U,1088
80
80
  ddi_fw/ml/pytorch_wrapper.py,sha256=AkG-2sKDXr0IBhgmkbjG0i20OuwQv3mhdvqp6UvJDCA,3716
81
- ddi_fw/ml/tensorflow_wrapper.py,sha256=ECLD5bl1sHKEwTvwkHHCRBV70Wmbxfejd1ix0Gbrh1g,5649
81
+ ddi_fw/ml/tensorflow_wrapper.py,sha256=oa9VEZpoHRXVoBKHfTclaVyksvF_6BVuMPeOS3-uJ2E,6409
82
82
  ddi_fw/ner/__init__.py,sha256=JwhGXrepomxPSsGsg2b_xPRC72AjvxOIn2CW5Mvscn0,26
83
83
  ddi_fw/ner/mmlrestclient.py,sha256=NZta7m2Qm6I_qtVguMZhqtAUjVBmmXn0-TMnsNp0jpg,6859
84
84
  ddi_fw/ner/ner.py,sha256=BEs9AFljAxOQrC2BEP1raSzRoypcfELS5UTdl4bjTqw,15863
@@ -106,7 +106,7 @@ ddi_fw/utils/package_helper.py,sha256=erl8_onmhK-41zQoaED2qyDUV9GQxmT9sdoyRp9_q5
106
106
  ddi_fw/utils/py7zr_helper.py,sha256=gOqaFIyJvTjUM-btO2x9AQ69jZOS8PoKN0wetYIckJw,4747
107
107
  ddi_fw/utils/utils.py,sha256=szwnxMTDRrZoeNRyDuf3aCbtzriwtaRk4mHSH3asLdA,4301
108
108
  ddi_fw/utils/zip_helper.py,sha256=YRZA4tKZVBJwGQM0_WK6L-y5MoqkKoC-nXuuHK6CU9I,5567
109
- ddi_fw-0.0.146.dist-info/METADATA,sha256=kyAlbvaawCIxFxHd1bltCpOhaXP4QeCnriPF258iOzI,1965
110
- ddi_fw-0.0.146.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
111
- ddi_fw-0.0.146.dist-info/top_level.txt,sha256=PMwHICFZTZtcpzQNPV4UQnfNXYIeLR_Ste-Wfc1h810,7
112
- ddi_fw-0.0.146.dist-info/RECORD,,
109
+ ddi_fw-0.0.148.dist-info/METADATA,sha256=TurH33n4534cFd-LUj63n152bTSpTJPCjINRtO21ILo,1965
110
+ ddi_fw-0.0.148.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
111
+ ddi_fw-0.0.148.dist-info/top_level.txt,sha256=PMwHICFZTZtcpzQNPV4UQnfNXYIeLR_Ste-Wfc1h810,7
112
+ ddi_fw-0.0.148.dist-info/RECORD,,