ddi-fw 0.0.62__py3-none-any.whl → 0.0.63__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

ddi_fw/experiments/evaluation_helper.py
@@ -85,10 +85,104 @@ def roc_aupr_score(y_true, y_score, average="macro"):
 
     return _average_binary_score(_binary_roc_aupr_score, y_true, y_score, average)
 
-# actual and pred are one-hot encoded
+
+def evaluate(actual, pred, info='', print=False):
+    # Precompute y_true and y_pred
+    y_true = np.argmax(actual, axis=1)
+    y_pred = np.argmax(pred, axis=1)
+
+    # Generate classification report
+    c_report = classification_report(y_true, y_pred, output_dict=True)
+
+    # Metrics initialization
+    metrics = Metrics(info)
+
+    n_classes = actual.shape[1]
+
+    precision = {}
+    recall = {}
+    f_score = {}
+    roc_aupr = {}
+    roc_auc = {
+        "weighted": 0,
+        "macro": 0,
+        "micro": 0
+    }
+
+    # Preallocate lists
+    precision_vals = [[] for _ in range(n_classes)]
+    recall_vals = [[] for _ in range(n_classes)]
+
+    # Compute metrics for each class
+    for i in range(n_classes):
+        precision_vals[i], recall_vals[i], _ = precision_recall_curve(
+            actual[:, i], pred[:, i])
+        roc_aupr[i] = auc(recall_vals[i], precision_vals[i])
+
+    # Calculate ROC AUC scores
+    roc_auc["weighted"] = roc_auc_score(actual, pred, multi_class='ovr', average='weighted')
+    roc_auc["macro"] = roc_auc_score(actual, pred, multi_class='ovr', average='macro')
+    roc_auc["micro"] = roc_auc_score(actual, pred, multi_class='ovr', average='micro')
+
+    # Micro-average Precision-Recall curve and ROC-AUPR
+    precision["micro_event"], recall["micro_event"], _ = precision_recall_curve(actual.ravel(), pred.ravel())
+    roc_aupr["micro"] = auc(recall["micro_event"], precision["micro_event"])
+
+    # Convert lists to numpy arrays for better performance
+    precision["micro_event"] = precision["micro_event"].tolist()
+    recall["micro_event"] = recall["micro_event"].tolist()
+
+    # Overall accuracy
+    acc = accuracy_score(y_true, y_pred)
+
+    # Aggregate precision, recall, and f_score
+    for avg_type in ['weighted', 'macro', 'micro']:
+        precision[avg_type] = precision_score(y_true, y_pred, average=avg_type)
+        recall[avg_type] = recall_score(y_true, y_pred, average=avg_type)
+        f_score[avg_type] = f1_score(y_true, y_pred, average=avg_type)
+
+    if print:
+        print(
+            f'''Accuracy: {acc}
+            , Precision:{precision['weighted']}
+            , Recall: {recall['weighted']}
+            , F1-score: {f_score['weighted']}
+            ''')
+
+    logs = {'accuracy': acc,
+            'weighted_precision': precision['weighted'],
+            'macro_precision': precision['macro'],
+            'micro_precision': precision['micro'],
+            'weighted_recall_score': recall['weighted'],
+            'macro_recall_score': recall['macro'],
+            'micro_recall_score': recall['micro'],
+            'weighted_f1_score': f_score['weighted'],
+            'macro_f1_score': f_score['macro'],
+            'micro_f1_score': f_score['micro'],
+            # 'weighted_roc_auc_score': weighted_roc_auc_score,
+            # 'macro_roc_auc_score': macro_roc_auc_score,
+            # 'micro_roc_auc_score': micro_roc_auc_score,
+            # 'macro_aupr_score': macro_aupr_score,
+            # 'micro_aupr_score': micro_aupr_score
+            "micro_roc_aupr": roc_aupr['micro'],
+            # "micro_precision_from_precision_recall_curve":precision["micro"],
+            # "micro_recall_from_precision_recall_curve":recall["micro"],
+            "weighted_roc_auc": roc_auc['weighted'],
+            "macro_roc_auc": roc_auc['macro'],
+            "micro_roc_auc": roc_auc['micro']
+            }
+    metrics.accuracy(acc)
+    metrics.precision(precision)
+    metrics.recall(recall)
+    metrics.f1_score(f_score)
+    metrics.roc_auc(roc_auc)
+    metrics.roc_aupr(roc_aupr)
+    metrics.classification_report(c_report)
+    return logs, metrics
 
 
-def evaluate(actual, pred, info = '' ,print=False):
+# actual and pred are one-hot encoded
+def evaluate_ex(actual, pred, info = '' ,print=False):
 
     y_pred = np.argmax(pred, axis=1)
     y_true = np.argmax(actual, axis=1)
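
A minimal usage sketch of the new evaluate helper (illustrative only, not part of the package diff). It assumes one-hot encoded labels and per-class probability scores, and that the function is importable from ddi_fw/experiments/evaluation_helper.py as listed in the RECORD below:

    import numpy as np

    # Assumed import path, inferred from the wheel's file list; evaluate() is new in 0.0.63.
    from ddi_fw.experiments.evaluation_helper import evaluate

    # Toy one-hot ground truth and probability-like predictions for 3 classes.
    actual = np.eye(3)[[0, 1, 2, 1, 0]]          # shape (5, 3), one-hot rows
    pred = np.array([[0.8, 0.1, 0.1],
                     [0.2, 0.6, 0.2],
                     [0.1, 0.2, 0.7],
                     [0.3, 0.5, 0.2],
                     [0.6, 0.3, 0.1]])

    # Returns a flat dict of scalar scores plus a Metrics object that also keeps
    # the per-class ROC-AUPR values and the classification report.
    logs, metrics = evaluate(actual, pred, info='demo-run')
    print(logs['accuracy'], logs['macro_roc_auc'])

The keys available in logs are exactly those listed in the hunk above: accuracy, weighted/macro/micro precision, recall and F1, micro_roc_aupr, and the weighted/macro/micro ROC AUC scores.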

ddi_fw/experiments/pipeline.py
@@ -33,7 +33,8 @@ class Experiment:
                  experiment_tags=None,
                  tracking_uri=None,
                  dataset_type:BaseDataset=None,
-                 columns=None,
+                 columns=None,
+                 embedding_dict = None,
                  vector_db_persist_directory=None,
                  vector_db_collection_name=None,
                  embedding_pooling_strategy_type:PoolingStrategy=None,
@@ -48,6 +49,7 @@ class Experiment:
         self.tracking_uri = tracking_uri
         self.dataset_type = dataset_type
         self.columns = columns
+        self.embedding_dict = embedding_dict
         self.vector_db_persist_directory = vector_db_persist_directory
         self.vector_db_collection_name = vector_db_collection_name
         self.embedding_pooling_strategy_type = embedding_pooling_strategy_type
@@ -61,21 +63,22 @@ class Experiment:
         kwargs = {"columns": self.columns}
         for k, v in self.ner_threshold.items():
             kwargs[k] = v
-        if self.vector_db_persist_directory:
-            self.vector_db = chromadb.PersistentClient(
-                path=self.vector_db_persist_directory)
-            self.collection = self.vector_db.get_collection(
-                self.vector_db_collection_name)
-            dictionary = self.collection.get(include=['embeddings', 'metadatas'])
+        if self.embedding_dict == None:
+            if self.vector_db_persist_directory:
+                self.vector_db = chromadb.PersistentClient(
+                    path=self.vector_db_persist_directory)
+                self.collection = self.vector_db.get_collection(
+                    self.vector_db_collection_name)
+                dictionary = self.collection.get(include=['embeddings', 'metadatas'])
 
-            embedding_dict = defaultdict(lambda: defaultdict(list))
+                embedding_dict = defaultdict(lambda: defaultdict(list))
 
-            for metadata, embedding in zip(dictionary['metadatas'], dictionary['embeddings']):
-                embedding_dict[metadata["type"]][metadata["id"]].append(embedding)
+                for metadata, embedding in zip(dictionary['metadatas'], dictionary['embeddings']):
+                    embedding_dict[metadata["type"]][metadata["id"]].append(embedding)
 
-            embedding_size = dictionary['embeddings'].shape[1]
+                embedding_size = dictionary['embeddings'].shape[1]
 
-            pooling_strategy = self.embedding_pooling_strategy_type()
+                pooling_strategy = self.embedding_pooling_strategy_type()
 
         self.ner_df = CTakesNER().load(filename=self.ner_data_file) if self.ner_data_file else None
 
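
A rough sketch of how the new embedding_dict parameter could be used to supply a precomputed embedding cache so that the ChromaDB PersistentClient lookup above is bypassed. The nested dict shape mirrors the one the pipeline builds itself from collection metadata; the Experiment call is hypothetical, with placeholder values for every other argument:

    from collections import defaultdict
    import numpy as np

    # Precomputed cache in the shape the pipeline otherwise builds from ChromaDB:
    # {entity_type: {entity_id: [embedding, ...]}}. "drug"/"DB00001" are example
    # keys; the real ones come from the collection's metadata "type" and "id".
    embedding_dict = defaultdict(lambda: defaultdict(list))
    embedding_dict["drug"]["DB00001"].append(np.random.rand(768))  # toy 768-dim vector
    embedding_dict["drug"]["DB00002"].append(np.random.rand(768))

    # Hypothetical constructor call with placeholder arguments:
    # experiment = Experiment(
    #     experiment_tags={"run": "demo"},
    #     tracking_uri="http://localhost:5000",
    #     embedding_dict=embedding_dict,  # new in 0.0.63: skips the PersistentClient branch
    #     vector_db_persist_directory=None,
    #     vector_db_collection_name=None,
    # )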

{ddi_fw-0.0.62.dist-info → ddi_fw-0.0.63.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ddi_fw
-Version: 0.0.62
+Version: 0.0.63
 Summary: Do not use :)
 Author-email: Kıvanç Bayraktar <bayraktarkivanc@gmail.com>
 Maintainer-email: Kıvanç Bayraktar <bayraktarkivanc@gmail.com>

{ddi_fw-0.0.62.dist-info → ddi_fw-0.0.63.dist-info}/RECORD
@@ -58,8 +58,8 @@ ddi_fw/drugbank/drugbank_processor_org.py,sha256=eO5Yset50P91qkic79RUXPoEuxRxQKF
 ddi_fw/drugbank/event_extractor.py,sha256=6odoZohhK7OdLF-LF0l-5BFq0_NMG_5jrFJbHrBXsI8,4600
 ddi_fw/experiments/__init__.py,sha256=5L2xSolpFycNnflqOMdvJSiqRB16ExA5bbVGORKFX04,195
 ddi_fw/experiments/custom_torch_model.py,sha256=iQ_R_EApzD2JCcASN8cie6D21oh7VCxaOQ45_dkiGwc,2576
-ddi_fw/experiments/evaluation_helper.py,sha256=pY69cezV3WzrXw1bduIwRJfah1w3wXJ2YyTNim1J7ko,9349
-ddi_fw/experiments/pipeline.py,sha256=wttkvdzGP9d3jC9nx2iZul4hbogXkRho6eDns0yfLiE,5380
+ddi_fw/experiments/evaluation_helper.py,sha256=o4-w5Xa3t4olLW4ymx_8L-Buhe5wfQEmT2bh4Zz544c,13066
+ddi_fw/experiments/pipeline.py,sha256=VsKPgYsGTY2bYIajRBAewBgP9-izmrL0Qtbn48qV5tw,5544
 ddi_fw/experiments/pipeline_builder_pattern.py,sha256=q1PNEQFoO5U3UidEoGB8rgLA7KXr4FsJTXEug5c5UJg,5466
 ddi_fw/experiments/pipeline_ner.py,sha256=unxEJCYrG6wEZjLmqvGdLRTMOBwELbGKkdygSpAR3b8,5043
 ddi_fw/experiments/tensorflow_helper.py,sha256=xUnbntWyc2Wm4TvmVFAnpwLHg-o13oM26GUHom6d5m0,11776
@@ -83,7 +83,7 @@ ddi_fw/utils/enums.py,sha256=19eJ3fX5eRK_xPvkYcukmug144jXPH4X9zQqtsFBj5A,671
 ddi_fw/utils/py7zr_helper.py,sha256=gOqaFIyJvTjUM-btO2x9AQ69jZOS8PoKN0wetYIckJw,4747
 ddi_fw/utils/utils.py,sha256=szwnxMTDRrZoeNRyDuf3aCbtzriwtaRk4mHSH3asLdA,4301
 ddi_fw/utils/zip_helper.py,sha256=YRZA4tKZVBJwGQM0_WK6L-y5MoqkKoC-nXuuHK6CU9I,5567
-ddi_fw-0.0.62.dist-info/METADATA,sha256=Osa0PYBMQcu8Pshz-QZ-uJ8lEcOU12zl0DmeXtCxREE,1565
-ddi_fw-0.0.62.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-ddi_fw-0.0.62.dist-info/top_level.txt,sha256=PMwHICFZTZtcpzQNPV4UQnfNXYIeLR_Ste-Wfc1h810,7
-ddi_fw-0.0.62.dist-info/RECORD,,
+ddi_fw-0.0.63.dist-info/METADATA,sha256=5fuV5oU6k1S0RAqMYGxR3nGuyV0lXpcayGsd9ydsEmI,1565
+ddi_fw-0.0.63.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+ddi_fw-0.0.63.dist-info/top_level.txt,sha256=PMwHICFZTZtcpzQNPV4UQnfNXYIeLR_Ste-Wfc1h810,7
+ddi_fw-0.0.63.dist-info/RECORD,,

{ddi_fw-0.0.62.dist-info → ddi_fw-0.0.63.dist-info}/WHEEL
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (75.1.0)
+Generator: setuptools (75.2.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 