ddi-fw 0.0.213__py3-none-any.whl → 0.0.214__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ddi_fw/ml/ml_helper.py CHANGED
@@ -156,7 +156,7 @@ class MultiModalRunner:
156
156
  (self.y_test_label.shape[0], self.y_test_label.shape[1]))
157
157
  for item in combination:
158
158
  prediction = prediction + single_results[item]
159
- prediction = np.argmax(prediction, axis=1)
159
+ prediction = utils.to_one_hot_encode(prediction)
160
160
  logs, metrics = evaluate(
161
161
  actual=self.y_test_label, pred=prediction, info=combination_descriptor)
162
162
  if self.use_mlflow:
ddi_fw/utils/__init__.py CHANGED
@@ -4,5 +4,5 @@ from .py7zr_helper import Py7ZipHelper
4
4
  from .enums import UMLSCodeTypes, DrugBankTextDataTypes
5
5
  from .package_helper import get_import
6
6
  from .kaggle import create_kaggle_dataset
7
- from .categorical_data_encoding_checker import is_one_hot_encoded, is_binary_encoded, is_binary_vector,is_label_encoded
7
+ from .categorical_data_encoding_checker import to_one_hot_encode,is_one_hot_encoded, is_binary_encoded, is_binary_vector,is_label_encoded
8
8
  from .numpy_utils import adjust_array_dims
@@ -1,5 +1,51 @@
1
1
  import numpy as np
2
2
 
3
def to_one_hot_encode(arr):
    """
    Convert an array of class indices or class scores into a one-hot array.

    Parameters:
    arr (numpy.ndarray): Either a 1D array of non-negative integer class
        indices, or an array with two or more dimensions whose last axis
        holds per-class scores (e.g. probabilities or summed predictions).

    Returns:
    numpy.ndarray: One-hot encoded array.
        * 1D input of length n -> float array of shape (n, max(arr) + 1).
        * 2D input -> float array with the same shape as the input, with a
          single 1 per row at the position of that row's maximum.
        * 3D or higher -> int array with the same shape as the input, with a
          single 1 along the last axis at the position of the maximum.

    Raises:
    ValueError: If the input is not a numpy array, is 0-dimensional, or is a
        1D array containing negative or non-integral values.
    """
    if not isinstance(arr, np.ndarray):
        raise ValueError("Input must be a numpy array")

    # A 0-D array has no axis to encode along; fail explicitly instead of
    # falling through every branch and hitting an unbound local.
    if arr.ndim == 0:
        raise ValueError("Input must have at least one dimension")

    # 1D: every element is a class index.
    if arr.ndim == 1:
        if arr.size == 0:
            # No samples means no classes; np.max would raise on empty input.
            return np.zeros((0, 0))

        indices = arr
        if not np.issubdtype(arr.dtype, np.integer):
            # Accept e.g. float arrays that hold whole-number labels, but
            # refuse values that would be silently truncated.
            indices = arr.astype(np.intp)
            if not np.array_equal(indices, arr):
                raise ValueError(
                    "1D input must contain integer class indices")
        if indices.min() < 0:
            # Negative indices would wrap around and mark the wrong class.
            raise ValueError("Class indices must be non-negative")

        num_classes = int(indices.max()) + 1
        one_hot = np.zeros((indices.shape[0], num_classes))
        one_hot[np.arange(indices.shape[0]), indices] = 1
        return one_hot

    # 2D and higher: the last axis holds class scores. The output dtype is
    # kept as it historically was (float for 2D, int for 3D+).
    out_dtype = float if arr.ndim == 2 else int
    one_hot = np.zeros(arr.shape, dtype=out_dtype)
    if arr.size == 0:
        # Nothing to mark; argmax over an empty axis would raise.
        return one_hot

    # Vectorised over all leading axes, so it is correct for any ndim >= 2
    # (the previous per-slice loop only worked for exactly 3 dimensions).
    winners = np.argmax(arr, axis=-1)
    np.put_along_axis(one_hot, winners[..., np.newaxis], 1, axis=-1)
    return one_hot
48
+
3
49
 
4
50
  def is_one_hot_encoded(arr):
5
51
  # Check if the array is one-hot encoded
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ddi_fw
3
- Version: 0.0.213
3
+ Version: 0.0.214
4
4
  Summary: Do not use :)
5
5
  Author-email: Kıvanç Bayraktar <bayraktarkivanc@gmail.com>
6
6
  Maintainer-email: Kıvanç Bayraktar <bayraktarkivanc@gmail.com>
@@ -74,7 +74,7 @@ ddi_fw/langchain/sentence_splitter.py,sha256=h_bYElx4Ud1mwDNJfL7mUwvgadwKX3GKlSz
74
74
  ddi_fw/langchain/storage.py,sha256=OizKyWm74Js7T6Q9kez-ulUoBGzIMFo4R46h4kjUyIM,11200
75
75
  ddi_fw/ml/__init__.py,sha256=tIxiW0g6q1VsmDYVXR_ovvHQR3SCir8g2bKxx_CrS7s,221
76
76
  ddi_fw/ml/evaluation_helper.py,sha256=2-7CLSgGTqLEk4HkgCVIOt-GxfLAn6SBozJghAtHb5M,11581
77
- ddi_fw/ml/ml_helper.py,sha256=m6_yoZwkKgYh0RRlXExfBaE63H_UgeFOXW9Dzy1kVig,7710
77
+ ddi_fw/ml/ml_helper.py,sha256=6BO1ikCHmlYK9TPDN7Atov0BuTtoyLg06NoSGl3RYGA,7716
78
78
  ddi_fw/ml/model_wrapper.py,sha256=kabPXuo7S8tGkp9a00V04n4rXDmv7dD8wYGMjotISRc,1050
79
79
  ddi_fw/ml/pytorch_wrapper.py,sha256=pe6UsjP2XeTgLxDnIUiodoyhJTGCxV27wD4Cjxysu2Q,8553
80
80
  ddi_fw/ml/tensorflow_wrapper.py,sha256=Vw6M2rHDHV90jzfCr0XWpUqYVl4vmZeKsS7FUb3VkH4,12980
@@ -86,8 +86,8 @@ ddi_fw/pipeline/multi_modal_combination_strategy.py,sha256=JSyuP71b1I1yuk0s2ecCJ
86
86
  ddi_fw/pipeline/multi_pipeline.py,sha256=SZFJ9QSPD_3mcG9NHZOtMqKyNvyWrodsdsLryMyDdUw,8686
87
87
  ddi_fw/pipeline/ner_pipeline.py,sha256=Bp6BA6nozfWFaMHH6jKlzesnCGO6qiMkzdGy_ed6nh0,5947
88
88
  ddi_fw/pipeline/pipeline.py,sha256=YhUBVLC29ZD2tmVd0e8X1FVBLhSKECZL2OP57oEW6HE,9171
89
- ddi_fw/utils/__init__.py,sha256=HC32XkYQTYH_9vt0eX6tqQngEFG-R70hGrYkT-BcHCk,519
90
- ddi_fw/utils/categorical_data_encoding_checker.py,sha256=gzb_vUDBrCMUhBxY1fBYTe8hmK72p0_uw3DTga8cqP8,1580
89
+ ddi_fw/utils/__init__.py,sha256=WNxkQXk-694roG50D355TGLXstfdWVb_tUyr-PM-8rg,537
90
+ ddi_fw/utils/categorical_data_encoding_checker.py,sha256=T1X70Rh4atucAuqyUZmz-iFULllY9dY0NRyV9-jTjJ0,3438
91
91
  ddi_fw/utils/enums.py,sha256=19eJ3fX5eRK_xPvkYcukmug144jXPH4X9zQqtsFBj5A,671
92
92
  ddi_fw/utils/json_helper.py,sha256=BVU6wmJgdXPxyqLPu3Ck_9Es5RrP1PDanKvE-OSj1D4,571
93
93
  ddi_fw/utils/kaggle.py,sha256=wKRJ18KpQ6P-CubpZklEgsDtyFpR9RUL1_HyyF6ttEE,2425
@@ -99,7 +99,7 @@ ddi_fw/utils/zip_helper.py,sha256=YRZA4tKZVBJwGQM0_WK6L-y5MoqkKoC-nXuuHK6CU9I,55
99
99
  ddi_fw/vectorization/__init__.py,sha256=LcJOpLVoLvHPDw9phGFlUQGeNcST_zKV-Oi1Pm5h_nE,110
100
100
  ddi_fw/vectorization/feature_vector_generation.py,sha256=EBf-XAiwQwr68az91erEYNegfeqssBR29kVgrliIyac,4765
101
101
  ddi_fw/vectorization/idf_helper.py,sha256=_Gd1dtDSLaw8o-o0JugzSKMt9FpeXewTh4wGEaUd4VQ,2571
102
- ddi_fw-0.0.213.dist-info/METADATA,sha256=BjGWPBaTzKY--kJGul2QSnf5Sd96hVdwIlMJzDPE9Eo,2631
103
- ddi_fw-0.0.213.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
104
- ddi_fw-0.0.213.dist-info/top_level.txt,sha256=PMwHICFZTZtcpzQNPV4UQnfNXYIeLR_Ste-Wfc1h810,7
105
- ddi_fw-0.0.213.dist-info/RECORD,,
102
+ ddi_fw-0.0.214.dist-info/METADATA,sha256=IEDJdH40Nw4B0aJXnUwuxeNRdXMX5rw1RBsX93Zbj1A,2631
103
+ ddi_fw-0.0.214.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
104
+ ddi_fw-0.0.214.dist-info/top_level.txt,sha256=PMwHICFZTZtcpzQNPV4UQnfNXYIeLR_Ste-Wfc1h810,7
105
+ ddi_fw-0.0.214.dist-info/RECORD,,