ddi-fw 0.0.179__py3-none-any.whl → 0.0.181__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -28,9 +28,10 @@ class TFModelWrapper(ModelWrapper):
28
28
  self.epochs = kwargs.get('epochs', 100)
29
29
  self.use_mlflow = use_mlflow
30
30
 
31
+ # TODO think different settings for num_classes
31
32
  def fit_model(self, X_train, y_train, X_valid, y_valid):
32
33
  self.kwargs['input_shape'] = self.train_data.shape
33
- self.num_classes = len(np.unique(y_train))
34
+ self.num_classes = len(np.unique(y_train, axis=0))
34
35
  model = self.model_func(**self.kwargs)
35
36
  checkpoint = ModelCheckpoint(
36
37
  filepath=f'{self.descriptor}_validation.weights.h5',
@@ -136,8 +137,11 @@ class TFModelWrapper(ModelWrapper):
136
137
  print(best_model_key)
137
138
  self.best_model: Model = best_model
138
139
  pred = self.predict()
139
- pred = tf.keras.utils.to_categorical(np.argmax(pred,axis=1), num_classes=self.num_classes)
140
- actual = tf.keras.utils.to_categorical(self.test_label, num_classes=self.num_classes)
140
+ actual = self.test_label
141
+ if not utils.is_binary_encoded(pred):
142
+ pred = tf.keras.utils.to_categorical(np.argmax(pred,axis=1), num_classes=self.num_classes)
143
+ if not utils.is_binary_encoded(actual):
144
+ actual = tf.keras.utils.to_categorical(actual, num_classes=self.num_classes)
141
145
 
142
146
  logs, metrics = evaluate(
143
147
  actual=actual, pred=pred, info=self.descriptor, print_detail=print_detail)
@@ -155,8 +159,11 @@ class TFModelWrapper(ModelWrapper):
155
159
  print(best_model_key)
156
160
  self.best_model = best_model
157
161
  pred = self.predict()
158
- pred = tf.keras.utils.to_categorical(np.argmax(pred,axis=1), num_classes=self.num_classes)
159
- actual = tf.keras.utils.to_categorical(self.test_label, num_classes=self.num_classes)
162
+ actual = self.test_label
163
+ if not utils.is_binary_encoded(pred):
164
+ pred = tf.keras.utils.to_categorical(np.argmax(pred,axis=1), num_classes=self.num_classes)
165
+ if not utils.is_binary_encoded(actual):
166
+ actual = tf.keras.utils.to_categorical(actual, num_classes=self.num_classes)
160
167
 
161
168
  logs, metrics = evaluate(
162
169
  actual=actual, pred=pred, info=self.descriptor)
ddi_fw/utils/__init__.py CHANGED
@@ -3,4 +3,5 @@ from .zip_helper import ZipHelper
3
3
  from .py7zr_helper import Py7ZipHelper
4
4
  from .enums import UMLSCodeTypes, DrugBankTextDataTypes
5
5
  from .package_helper import get_import
6
- from .kaggle import create_kaggle_dataset
6
+ from .kaggle import create_kaggle_dataset
7
+ from .categorical_data_encoding_checker import is_one_hot_encoded, is_binary_encoded, is_binary_vector
@@ -0,0 +1,32 @@
1
+ import numpy as np
2
+
3
+
4
def is_one_hot_encoded(arr):
    """Return True if ``arr`` is a row-wise one-hot encoded 2D matrix.

    A matrix counts as one-hot encoded when every element is 0 or 1 and
    every row contains exactly one 1.

    Args:
        arr: The NumPy array to inspect.

    Returns:
        bool: True if ``arr`` is one-hot encoded row-wise, False otherwise.

    Raises:
        ValueError: If ``arr`` is not a NumPy ndarray, or is not 2D.
    """
    # ValueError (rather than TypeError) is kept so existing callers that
    # catch ValueError keep working.
    if not isinstance(arr, np.ndarray):
        raise ValueError("Input must be a NumPy ndarray.")
    # Without this guard a 1D input fails with a numpy AxisError from the
    # axis=1 sum below, and a 3D input is summed along the wrong axis.
    if arr.ndim != 2:
        raise ValueError("Input must be a 2D array.")
    if not np.all(np.isin(arr, [0, 1])):
        return False
    # Row-wise check: each row must have exactly one "1".
    return bool(np.all(np.sum(arr, axis=1) == 1))
13
+
14
+
15
def is_binary_encoded(arr):
    """Check whether a 2D NumPy array contains only the values 0 and 1.

    Args:
        arr: Array to inspect; must be a ``np.ndarray`` with exactly two
            dimensions.

    Returns:
        The result of ``np.all`` over the element-wise membership test:
        truthy when every element equals 0 or 1.

    Raises:
        ValueError: If ``arr`` is not an ndarray, or is not 2D.
    """
    is_ndarray = isinstance(arr, np.ndarray)
    if not is_ndarray:
        raise ValueError("Input must be a NumPy ndarray.")
    if arr.ndim == 2:
        allowed_values = [0, 1]
        membership = np.isin(arr, allowed_values)
        return np.all(membership)
    raise ValueError("Input must be a 2D array.")
24
+
25
+
26
def is_binary_vector(arr):
    """Return True if ``arr`` is a 1D array whose elements are all 0 or 1.

    Args:
        arr: The NumPy array to inspect.

    Returns:
        bool: True when every element of the 1D array equals 0 or 1.

    Raises:
        ValueError: If ``arr`` is not a NumPy ndarray, or is not 1D.
    """
    if not isinstance(arr, np.ndarray):
        raise ValueError("Input must be a NumPy ndarray.")
    if arr.ndim != 1:
        raise ValueError("Input must be a 1D array.")
    # ndim is guaranteed to be 1 at this point, so only the values need
    # checking; the former extra `arr.ndim == 1 and` test was dead code.
    return bool(np.all(np.isin(arr, [0, 1])))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ddi_fw
3
- Version: 0.0.179
3
+ Version: 0.0.181
4
4
  Summary: Do not use :)
5
5
  Author-email: Kıvanç Bayraktar <bayraktarkivanc@gmail.com>
6
6
  Maintainer-email: Kıvanç Bayraktar <bayraktarkivanc@gmail.com>
@@ -77,7 +77,7 @@ ddi_fw/ml/evaluation_helper.py,sha256=2-7CLSgGTqLEk4HkgCVIOt-GxfLAn6SBozJghAtHb5
77
77
  ddi_fw/ml/ml_helper.py,sha256=E6ef7f1UnQl6JBUdGDbbbI4FIS-904VGypT7tI0a598,8545
78
78
  ddi_fw/ml/model_wrapper.py,sha256=kabPXuo7S8tGkp9a00V04n4rXDmv7dD8wYGMjotISRc,1050
79
79
  ddi_fw/ml/pytorch_wrapper.py,sha256=pe6UsjP2XeTgLxDnIUiodoyhJTGCxV27wD4Cjxysu2Q,8553
80
- ddi_fw/ml/tensorflow_wrapper.py,sha256=AeEXGbsQW6BgVf-Mgxe9NbvwNqLOqqCTGyTNxfg4G_Y,10564
80
+ ddi_fw/ml/tensorflow_wrapper.py,sha256=9WhjL00-WSzE0HJzr_T38IrD3ipCKGrrkeBGEEewUKw,10920
81
81
  ddi_fw/ner/__init__.py,sha256=JwhGXrepomxPSsGsg2b_xPRC72AjvxOIn2CW5Mvscn0,26
82
82
  ddi_fw/ner/mmlrestclient.py,sha256=NZta7m2Qm6I_qtVguMZhqtAUjVBmmXn0-TMnsNp0jpg,6859
83
83
  ddi_fw/ner/ner.py,sha256=FHyyX53Xwpdw8Hec261dyN88yD7Z9LmJua2mIrQLguI,17967
@@ -86,7 +86,8 @@ ddi_fw/pipeline/multi_modal_combination_strategy.py,sha256=JSyuP71b1I1yuk0s2ecCJ
86
86
  ddi_fw/pipeline/multi_pipeline.py,sha256=NfcH4Ze5U-JRiH3lrxEDWj-VPxYQYtp7tq6bLCImBzs,5550
87
87
  ddi_fw/pipeline/ner_pipeline.py,sha256=kNGtkg5rNX5MDywzvRxmvyk-DxXAjEbYzZkp8pNlAZo,6023
88
88
  ddi_fw/pipeline/pipeline.py,sha256=70lYsluAnTWDLTlf6rbecffw3Bl34L1_6ALfLUoSvtY,11324
89
- ddi_fw/utils/__init__.py,sha256=77563ikqAtdzjjgRlLp5OAsJBbpLA1Cao8iecGaVUXQ,354
89
+ ddi_fw/utils/__init__.py,sha256=L64M3YCB56eSiB-rE2Zmn6DMNNwqHPWHq6Z_f4fi9VQ,458
90
+ ddi_fw/utils/categorical_data_encoding_checker.py,sha256=DNbxjpyD8XTqILSHHQ0_VUd61PNBSupSAuXiq5nLTK8,1122
90
91
  ddi_fw/utils/enums.py,sha256=19eJ3fX5eRK_xPvkYcukmug144jXPH4X9zQqtsFBj5A,671
91
92
  ddi_fw/utils/json_helper.py,sha256=BVU6wmJgdXPxyqLPu3Ck_9Es5RrP1PDanKvE-OSj1D4,571
92
93
  ddi_fw/utils/kaggle.py,sha256=wKRJ18KpQ6P-CubpZklEgsDtyFpR9RUL1_HyyF6ttEE,2425
@@ -97,7 +98,7 @@ ddi_fw/utils/zip_helper.py,sha256=YRZA4tKZVBJwGQM0_WK6L-y5MoqkKoC-nXuuHK6CU9I,55
97
98
  ddi_fw/vectorization/__init__.py,sha256=LcJOpLVoLvHPDw9phGFlUQGeNcST_zKV-Oi1Pm5h_nE,110
98
99
  ddi_fw/vectorization/feature_vector_generation.py,sha256=Z1A_DOBqDFPqLN4YB-3oYlOQWJK-X6Oes6UFjpzR47Q,4760
99
100
  ddi_fw/vectorization/idf_helper.py,sha256=_Gd1dtDSLaw8o-o0JugzSKMt9FpeXewTh4wGEaUd4VQ,2571
100
- ddi_fw-0.0.179.dist-info/METADATA,sha256=vXo7V9eOR-nGS0HRf-TJPD-sZdKWhlYAP4ycIUlr_N8,2542
101
- ddi_fw-0.0.179.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
102
- ddi_fw-0.0.179.dist-info/top_level.txt,sha256=PMwHICFZTZtcpzQNPV4UQnfNXYIeLR_Ste-Wfc1h810,7
103
- ddi_fw-0.0.179.dist-info/RECORD,,
101
+ ddi_fw-0.0.181.dist-info/METADATA,sha256=dNVh9x2lxjHNUWnNkC6-cYm19JxkHNKfAwiYWfvVsro,2542
102
+ ddi_fw-0.0.181.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
103
+ ddi_fw-0.0.181.dist-info/top_level.txt,sha256=PMwHICFZTZtcpzQNPV4UQnfNXYIeLR_Ste-Wfc1h810,7
104
+ ddi_fw-0.0.181.dist-info/RECORD,,