ddi-fw 0.0.172__py3-none-any.whl → 0.0.173__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddi_fw/datasets/core.py +3 -2
- ddi_fw/ml/evaluation_helper.py +132 -31
- ddi_fw/ml/tensorflow_wrapper.py +4 -0
- {ddi_fw-0.0.172.dist-info → ddi_fw-0.0.173.dist-info}/METADATA +1 -1
- {ddi_fw-0.0.172.dist-info → ddi_fw-0.0.173.dist-info}/RECORD +7 -7
- {ddi_fw-0.0.172.dist-info → ddi_fw-0.0.173.dist-info}/WHEEL +0 -0
- {ddi_fw-0.0.172.dist-info → ddi_fw-0.0.173.dist-info}/top_level.txt +0 -0
ddi_fw/datasets/core.py
CHANGED
@@ -180,12 +180,13 @@ class BaseDataset(BaseModel):
|
|
180
180
|
if self.X_train is not None or self.X_test is not None:
|
181
181
|
raise Exception(
|
182
182
|
"X_train and X_test are already present. Splitting is not allowed.")
|
183
|
-
|
183
|
+
|
184
|
+
self.prep()
|
184
185
|
if self.dataframe is None:
|
185
186
|
raise Exception("There is no dataframe to split.")
|
186
187
|
|
187
188
|
save_path = self.index_path
|
188
|
-
|
189
|
+
|
189
190
|
|
190
191
|
X = self.dataframe.drop(self.class_column, axis=1)
|
191
192
|
y = self.dataframe[self.class_column]
|
ddi_fw/ml/evaluation_helper.py
CHANGED
@@ -11,6 +11,7 @@ from sklearn.metrics import auc
|
|
11
11
|
from sklearn.metrics import classification_report
|
12
12
|
from sklearn.preprocessing import OneHotEncoder
|
13
13
|
|
14
|
+
|
14
15
|
def __format__(d: Union[Dict[str, Union[List[float], float]], float], floating_number_precision=4) -> Union[Dict[str, Union[List[float], float]], float]:
|
15
16
|
if isinstance(d, dict):
|
16
17
|
d = {k: __round__(v, floating_number_precision) for k, v in d.items()}
|
@@ -18,25 +19,25 @@ def __format__(d: Union[Dict[str, Union[List[float], float]], float], floating_n
|
|
18
19
|
d = round(d, floating_number_precision)
|
19
20
|
return d
|
20
21
|
|
21
|
-
|
22
|
+
|
23
|
+
def __round__(v, floating_number_precision=4) -> Union[List[float], float]:
|
22
24
|
if type(v) is list or type(v) is set:
|
23
|
-
|
25
|
+
return [round(item, floating_number_precision) for item in v]
|
24
26
|
else:
|
25
|
-
return round(v,floating_number_precision)
|
27
|
+
return round(v, floating_number_precision)
|
26
28
|
|
27
29
|
|
28
30
|
class Metrics(BaseModel):
|
29
31
|
label: str
|
30
32
|
accuracy: float = 0.0
|
31
|
-
precision: Any= None
|
32
|
-
recall: Any= None
|
33
|
-
f1_score: Any= None
|
34
|
-
roc_auc: Any= None
|
35
|
-
roc_aupr: Any= None
|
33
|
+
precision: Any = None
|
34
|
+
recall: Any = None
|
35
|
+
f1_score: Any = None
|
36
|
+
roc_auc: Any = None
|
37
|
+
roc_aupr: Any = None
|
36
38
|
classification_report: Any = None
|
37
39
|
|
38
|
-
|
39
|
-
def set_classification_report(self,classification_report):
|
40
|
+
def set_classification_report(self, classification_report):
|
40
41
|
self.classification_report = classification_report
|
41
42
|
|
42
43
|
def set_accuracy(self, accuracy):
|
@@ -57,15 +58,14 @@ class Metrics(BaseModel):
|
|
57
58
|
def set_roc_aupr(self, roc_aupr):
|
58
59
|
self.roc_aupr = roc_aupr
|
59
60
|
|
60
|
-
def format_float(self, floating_number_precision
|
61
|
-
self.accuracy = round(self.accuracy,floating_number_precision)
|
62
|
-
self.precision = __format__(
|
63
|
-
self.recall = __format__(
|
64
|
-
self.f1_score = __format__(
|
65
|
-
self.roc_auc = __format__(
|
66
|
-
self.roc_aupr = __format__(
|
61
|
+
def format_float(self, floating_number_precision=4):
|
62
|
+
self.accuracy = round(self.accuracy, floating_number_precision)
|
63
|
+
self.precision = __format__(self.precision, floating_number_precision)
|
64
|
+
self.recall = __format__(self.recall, floating_number_precision)
|
65
|
+
self.f1_score = __format__(self.f1_score, floating_number_precision)
|
66
|
+
self.roc_auc = __format__(self.roc_auc, floating_number_precision)
|
67
|
+
self.roc_aupr = __format__(self.roc_aupr, floating_number_precision)
|
67
68
|
|
68
|
-
|
69
69
|
|
70
70
|
# taken from https://github.com/YifanDengWHU/DDIMDL/blob/master/DDIMDL.py#L214
|
71
71
|
def roc_aupr_score(y_true, y_score, average="macro"):
|
@@ -96,19 +96,119 @@ def roc_aupr_score(y_true, y_score, average="macro"):
|
|
96
96
|
return _average_binary_score(_binary_roc_aupr_score, y_true, y_score, average)
|
97
97
|
|
98
98
|
|
99
|
-
def evaluate(actual, pred, info='', print_detail=False):
|
99
|
+
def evaluate(actual: np.ndarray, pred: np.ndarray, info='', print_detail=False):
|
100
|
+
y_true = actual
|
101
|
+
y_pred = pred
|
102
|
+
|
103
|
+
# Generate classification report
|
104
|
+
c_report = classification_report(y_true, y_pred, output_dict=True)
|
105
|
+
|
106
|
+
# Metrics initialization
|
107
|
+
metrics = Metrics(label=info)
|
108
|
+
|
109
|
+
n_classes = actual.shape[1]
|
110
|
+
# n_classes = len(np.unique(actual))
|
111
|
+
|
112
|
+
precision = {}
|
113
|
+
recall = {}
|
114
|
+
f_score = {}
|
115
|
+
roc_aupr = {}
|
116
|
+
roc_auc = {
|
117
|
+
"weighted": 0.0,
|
118
|
+
"macro": 0.0,
|
119
|
+
"micro": 0.0
|
120
|
+
}
|
121
|
+
|
122
|
+
# Preallocate lists
|
123
|
+
precision_vals: List[np.ndarray] = [np.array([]) for _ in range(n_classes)]
|
124
|
+
recall_vals: List[np.ndarray] = [np.array([]) for _ in range(n_classes)]
|
125
|
+
|
126
|
+
# Compute metrics for each class
|
127
|
+
for i in range(n_classes):
|
128
|
+
precision_vals[i], recall_vals[i], _ = precision_recall_curve(
|
129
|
+
actual[:, i], pred[:, i])
|
130
|
+
roc_aupr[i] = auc(recall_vals[i], precision_vals[i])
|
131
|
+
|
132
|
+
# Calculate ROC AUC scores
|
133
|
+
roc_auc["weighted"] = float(roc_auc_score(
|
134
|
+
actual, pred, multi_class='ovr', average='weighted'))
|
135
|
+
roc_auc["macro"] = float(roc_auc_score(
|
136
|
+
actual, pred, multi_class='ovr', average='macro'))
|
137
|
+
roc_auc["micro"] = float(roc_auc_score(
|
138
|
+
actual, pred, multi_class='ovr', average='micro'))
|
139
|
+
|
140
|
+
# Micro-average Precision-Recall curve and ROC-AUPR
|
141
|
+
precision["micro_event"], recall["micro_event"], _ = precision_recall_curve(
|
142
|
+
actual.ravel(), pred.ravel())
|
143
|
+
roc_aupr["micro"] = auc(recall["micro_event"], precision["micro_event"])
|
144
|
+
|
145
|
+
# Convert lists to numpy arrays for better performance
|
146
|
+
precision["micro_event"] = precision["micro_event"].tolist()
|
147
|
+
recall["micro_event"] = recall["micro_event"].tolist()
|
148
|
+
|
149
|
+
# Overall accuracy
|
150
|
+
acc = accuracy_score(y_true, y_pred)
|
151
|
+
|
152
|
+
# Aggregate precision, recall, and f_score
|
153
|
+
# for avg_type in ['weighted', 'macro', 'micro']:
|
154
|
+
for avg_type in Literal['weighted', 'macro', 'micro'].__args__:
|
155
|
+
precision[avg_type] = precision_score(y_true, y_pred, average=avg_type)
|
156
|
+
recall[avg_type] = recall_score(y_true, y_pred, average=avg_type)
|
157
|
+
f_score[avg_type] = f1_score(y_true, y_pred, average=avg_type)
|
158
|
+
|
159
|
+
if print_detail:
|
160
|
+
print(
|
161
|
+
f'''Accuracy: {acc}
|
162
|
+
, Precision:{precision['weighted']}
|
163
|
+
, Recall: {recall['weighted']}
|
164
|
+
, F1-score: {f_score['weighted']}
|
165
|
+
''')
|
166
|
+
|
167
|
+
logs = {'accuracy': acc,
|
168
|
+
'weighted_precision': precision['weighted'],
|
169
|
+
'macro_precision': precision['macro'],
|
170
|
+
'micro_precision': precision['micro'],
|
171
|
+
'weighted_recall_score': recall['weighted'],
|
172
|
+
'macro_recall_score': recall['macro'],
|
173
|
+
'micro_recall_score': recall['micro'],
|
174
|
+
'weighted_f1_score': f_score['weighted'],
|
175
|
+
'macro_f1_score': f_score['macro'],
|
176
|
+
'micro_f1_score': f_score['micro'],
|
177
|
+
# 'weighted_roc_auc_score': weighted_roc_auc_score,
|
178
|
+
# 'macro_roc_auc_score': macro_roc_auc_score,
|
179
|
+
# 'micro_roc_auc_score': micro_roc_auc_score,
|
180
|
+
# 'macro_aupr_score': macro_aupr_score,
|
181
|
+
# 'micro_aupr_score': micro_aupr_score
|
182
|
+
"micro_roc_aupr": roc_aupr['micro'],
|
183
|
+
# "micro_precision_from_precision_recall_curve":precision["micro"],
|
184
|
+
# "micro_recall_from_precision_recall_curve":recall["micro"],
|
185
|
+
"weighted_roc_auc": roc_auc['weighted'],
|
186
|
+
"macro_roc_auc": roc_auc['macro'],
|
187
|
+
"micro_roc_auc": roc_auc['micro']
|
188
|
+
}
|
189
|
+
metrics.set_accuracy(acc)
|
190
|
+
metrics.set_precision(precision)
|
191
|
+
metrics.set_recall(recall)
|
192
|
+
metrics.set_f1_score(f_score)
|
193
|
+
metrics.set_roc_auc(roc_auc)
|
194
|
+
metrics.set_roc_aupr(roc_aupr)
|
195
|
+
metrics.set_classification_report(c_report)
|
196
|
+
return logs, metrics
|
197
|
+
|
198
|
+
|
199
|
+
def evaluate_ex(actual, pred, info='', print_detail=False):
|
100
200
|
# Precompute y_true and y_pred
|
101
201
|
y_true = np.argmax(actual, axis=1)
|
102
202
|
y_pred = np.argmax(pred, axis=1)
|
103
|
-
|
203
|
+
|
104
204
|
# Generate classification report
|
105
205
|
c_report = classification_report(y_true, y_pred, output_dict=True)
|
106
|
-
|
206
|
+
|
107
207
|
# Metrics initialization
|
108
|
-
metrics = Metrics(label=
|
109
|
-
|
208
|
+
metrics = Metrics(label=info)
|
209
|
+
|
110
210
|
n_classes = actual.shape[1]
|
111
|
-
|
211
|
+
|
112
212
|
precision = {}
|
113
213
|
recall = {}
|
114
214
|
f_score = {}
|
@@ -123,7 +223,6 @@ def evaluate(actual, pred, info='', print_detail=False):
|
|
123
223
|
precision_vals: List[np.ndarray] = [np.array([]) for _ in range(n_classes)]
|
124
224
|
recall_vals: List[np.ndarray] = [np.array([]) for _ in range(n_classes)]
|
125
225
|
|
126
|
-
|
127
226
|
# Compute metrics for each class
|
128
227
|
for i in range(n_classes):
|
129
228
|
precision_vals[i], recall_vals[i], _ = precision_recall_curve(
|
@@ -131,12 +230,16 @@ def evaluate(actual, pred, info='', print_detail=False):
|
|
131
230
|
roc_aupr[i] = auc(recall_vals[i], precision_vals[i])
|
132
231
|
|
133
232
|
# Calculate ROC AUC scores
|
134
|
-
roc_auc["weighted"] = float(roc_auc_score(
|
135
|
-
|
136
|
-
roc_auc["
|
233
|
+
roc_auc["weighted"] = float(roc_auc_score(
|
234
|
+
actual, pred, multi_class='ovr', average='weighted'))
|
235
|
+
roc_auc["macro"] = float(roc_auc_score(
|
236
|
+
actual, pred, multi_class='ovr', average='macro'))
|
237
|
+
roc_auc["micro"] = float(roc_auc_score(
|
238
|
+
actual, pred, multi_class='ovr', average='micro'))
|
137
239
|
|
138
240
|
# Micro-average Precision-Recall curve and ROC-AUPR
|
139
|
-
precision["micro_event"], recall["micro_event"], _ = precision_recall_curve(
|
241
|
+
precision["micro_event"], recall["micro_event"], _ = precision_recall_curve(
|
242
|
+
actual.ravel(), pred.ravel())
|
140
243
|
roc_aupr["micro"] = auc(recall["micro_event"], precision["micro_event"])
|
141
244
|
|
142
245
|
# Convert lists to numpy arrays for better performance
|
@@ -191,5 +294,3 @@ def evaluate(actual, pred, info='', print_detail=False):
|
|
191
294
|
metrics.set_roc_aupr(roc_aupr)
|
192
295
|
metrics.set_classification_report(c_report)
|
193
296
|
return logs, metrics
|
194
|
-
|
195
|
-
|
ddi_fw/ml/tensorflow_wrapper.py
CHANGED
@@ -30,6 +30,7 @@ class TFModelWrapper(ModelWrapper):
|
|
30
30
|
|
31
31
|
def fit_model(self, X_train, y_train, X_valid, y_valid):
|
32
32
|
self.kwargs['input_shape'] = self.train_data.shape
|
33
|
+
self.num_classes = len(np.unique(y_train))
|
33
34
|
model = self.model_func(**self.kwargs)
|
34
35
|
checkpoint = ModelCheckpoint(
|
35
36
|
filepath=f'{self.descriptor}_validation.weights.h5',
|
@@ -135,6 +136,9 @@ class TFModelWrapper(ModelWrapper):
|
|
135
136
|
print(best_model_key)
|
136
137
|
self.best_model: Model = best_model
|
137
138
|
pred = self.predict()
|
139
|
+
pred = tf.keras.utils.to_categorical(np.argmax(pred,axis=1), num_classes=self.num_classes)
|
140
|
+
actual = tf.keras.utils.to_categorical(self.test_label, num_classes=self.num_classes)
|
141
|
+
|
138
142
|
logs, metrics = evaluate(
|
139
143
|
actual=self.test_label, pred=pred, info=self.descriptor, print_detail=print_detail)
|
140
144
|
metrics.format_float()
|
@@ -1,5 +1,5 @@
|
|
1
1
|
ddi_fw/datasets/__init__.py,sha256=_I3iDHARwzmg7_EL5XKtB_TgG1yAkLSOVTujLL9Wz9Q,280
|
2
|
-
ddi_fw/datasets/core.py,sha256=
|
2
|
+
ddi_fw/datasets/core.py,sha256=j6YpH6IqPQ2va1cC26xT-Jn3fIPsF43xD3GuluJRJb4,9372
|
3
3
|
ddi_fw/datasets/dataset_splitter.py,sha256=8H8uZTAf8N9LUZeSeHOMawtJFJhnDgUUqFcnl7dquBQ,1672
|
4
4
|
ddi_fw/datasets/db_utils.py,sha256=OTsa3d-Iic7z3HmzSQK9UigedRbHDxYChJk0s4GfLnw,6191
|
5
5
|
ddi_fw/datasets/setup_._py,sha256=khYVJuW5PlOY_i_A16F3UbSZ6s6o_ljw33Byw3C-A8E,1047
|
@@ -73,11 +73,11 @@ ddi_fw/langchain/embeddings.py,sha256=eEWy4okcjdhUJHi4N48Wd8XauPXyeaQVLUdNWEvtEc
|
|
73
73
|
ddi_fw/langchain/sentence_splitter.py,sha256=h_bYElx4Ud1mwDNJfL7mUwvgadwKX3GKlSzu5L2PXzg,280
|
74
74
|
ddi_fw/langchain/storage.py,sha256=OizKyWm74Js7T6Q9kez-ulUoBGzIMFo4R46h4kjUyIM,11200
|
75
75
|
ddi_fw/ml/__init__.py,sha256=tIxiW0g6q1VsmDYVXR_ovvHQR3SCir8g2bKxx_CrS7s,221
|
76
|
-
ddi_fw/ml/evaluation_helper.py,sha256=
|
76
|
+
ddi_fw/ml/evaluation_helper.py,sha256=2-7CLSgGTqLEk4HkgCVIOt-GxfLAn6SBozJghAtHb5M,11581
|
77
77
|
ddi_fw/ml/ml_helper.py,sha256=E6ef7f1UnQl6JBUdGDbbbI4FIS-904VGypT7tI0a598,8545
|
78
78
|
ddi_fw/ml/model_wrapper.py,sha256=kabPXuo7S8tGkp9a00V04n4rXDmv7dD8wYGMjotISRc,1050
|
79
79
|
ddi_fw/ml/pytorch_wrapper.py,sha256=pe6UsjP2XeTgLxDnIUiodoyhJTGCxV27wD4Cjxysu2Q,8553
|
80
|
-
ddi_fw/ml/tensorflow_wrapper.py,sha256=
|
80
|
+
ddi_fw/ml/tensorflow_wrapper.py,sha256=IQq0KSU-WuRI90b3DcZ8vhxATfZgdymkAqiiz4a1D6g,10377
|
81
81
|
ddi_fw/ner/__init__.py,sha256=JwhGXrepomxPSsGsg2b_xPRC72AjvxOIn2CW5Mvscn0,26
|
82
82
|
ddi_fw/ner/mmlrestclient.py,sha256=NZta7m2Qm6I_qtVguMZhqtAUjVBmmXn0-TMnsNp0jpg,6859
|
83
83
|
ddi_fw/ner/ner.py,sha256=FHyyX53Xwpdw8Hec261dyN88yD7Z9LmJua2mIrQLguI,17967
|
@@ -97,7 +97,7 @@ ddi_fw/utils/zip_helper.py,sha256=YRZA4tKZVBJwGQM0_WK6L-y5MoqkKoC-nXuuHK6CU9I,55
|
|
97
97
|
ddi_fw/vectorization/__init__.py,sha256=LcJOpLVoLvHPDw9phGFlUQGeNcST_zKV-Oi1Pm5h_nE,110
|
98
98
|
ddi_fw/vectorization/feature_vector_generation.py,sha256=Z1A_DOBqDFPqLN4YB-3oYlOQWJK-X6Oes6UFjpzR47Q,4760
|
99
99
|
ddi_fw/vectorization/idf_helper.py,sha256=_Gd1dtDSLaw8o-o0JugzSKMt9FpeXewTh4wGEaUd4VQ,2571
|
100
|
-
ddi_fw-0.0.
|
101
|
-
ddi_fw-0.0.
|
102
|
-
ddi_fw-0.0.
|
103
|
-
ddi_fw-0.0.
|
100
|
+
ddi_fw-0.0.173.dist-info/METADATA,sha256=4HVYwgrsyel7JO4cJ3pZTtw5G_YwmRsrNyIClsmJaFo,2542
|
101
|
+
ddi_fw-0.0.173.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
102
|
+
ddi_fw-0.0.173.dist-info/top_level.txt,sha256=PMwHICFZTZtcpzQNPV4UQnfNXYIeLR_Ste-Wfc1h810,7
|
103
|
+
ddi_fw-0.0.173.dist-info/RECORD,,
|
File without changes
|
File without changes
|