ddi-fw 0.0.225__py3-none-any.whl → 0.0.227__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddi_fw/datasets/core.py +3 -2
- ddi_fw/ml/tensorflow_wrapper.py +11 -1
- ddi_fw/pipeline/multi_pipeline.py +4 -4
- ddi_fw/pipeline/ner_pipeline.py +53 -29
- {ddi_fw-0.0.225.dist-info → ddi_fw-0.0.227.dist-info}/METADATA +2 -2
- {ddi_fw-0.0.225.dist-info → ddi_fw-0.0.227.dist-info}/RECORD +8 -8
- {ddi_fw-0.0.225.dist-info → ddi_fw-0.0.227.dist-info}/WHEEL +1 -1
- {ddi_fw-0.0.225.dist-info → ddi_fw-0.0.227.dist-info}/top_level.txt +0 -0
ddi_fw/datasets/core.py
CHANGED
@@ -4,7 +4,7 @@ import glob
|
|
4
4
|
import logging
|
5
5
|
from typing import Any, Dict, List, Optional, Type
|
6
6
|
import chromadb
|
7
|
-
from chromadb.api.types import IncludeEnum
|
7
|
+
# from chromadb.api.types import IncludeEnum
|
8
8
|
import numpy as np
|
9
9
|
import pandas as pd
|
10
10
|
from pydantic import BaseModel, Field, computed_field
|
@@ -307,7 +307,8 @@ class TextDatasetMixin(BaseModel):
|
|
307
307
|
vector_db = chromadb.PersistentClient(
|
308
308
|
path=vector_db_persist_directory)
|
309
309
|
collection = vector_db.get_collection(vector_db_collection_name)
|
310
|
-
include = [IncludeEnum.embeddings, IncludeEnum.metadatas]
|
310
|
+
# include = [IncludeEnum.embeddings, IncludeEnum.metadatas]
|
311
|
+
include: chromadb.Include = ["embeddings","metadatas"]
|
311
312
|
dictionary: chromadb.GetResult
|
312
313
|
# Fetch the embeddings and metadata
|
313
314
|
if column == None:
|
ddi_fw/ml/tensorflow_wrapper.py
CHANGED
@@ -102,9 +102,18 @@ class TFModelWrapper(ModelWrapper):
|
|
102
102
|
history = model.fit(
|
103
103
|
train_dataset,
|
104
104
|
epochs=self.epochs,
|
105
|
-
|
105
|
+
validation_data=val_dataset,
|
106
106
|
callbacks=callbacks
|
107
107
|
)
|
108
|
+
|
109
|
+
# Check if early stopping was applied
|
110
|
+
if early_stopping.stopped_epoch > 0:
|
111
|
+
print(f"Early stopping was applied at epoch {early_stopping.stopped_epoch}.")
|
112
|
+
else:
|
113
|
+
print("Early stopping was not applied.")
|
114
|
+
if self.tracking_service:
|
115
|
+
self.tracking_service.log_param("early_stopping_applied", early_stopping.stopped_epoch > 0)
|
116
|
+
self.tracking_service.log_param("early_stopping_epoch", early_stopping.stopped_epoch)
|
108
117
|
# ex
|
109
118
|
# history = model.fit(
|
110
119
|
# X_train, y_train,
|
@@ -174,6 +183,7 @@ class TFModelWrapper(ModelWrapper):
|
|
174
183
|
if models_val_acc == {}:
|
175
184
|
return model, None
|
176
185
|
best_model_key = max(models_val_acc, key=lambda k: models_val_acc[k])
|
186
|
+
print("best model key: ", best_model_key)
|
177
187
|
# best_model_key = max(models_val_acc, key=models_val_acc.get)
|
178
188
|
best_model = models[best_model_key]
|
179
189
|
return best_model, best_model_key
|
ddi_fw/pipeline/multi_pipeline.py
CHANGED
@@ -191,15 +191,15 @@ class MultiPipeline():
|
|
191
191
|
elif type== "ner_search":
|
192
192
|
pipeline = NerParameterSearch(
|
193
193
|
library=library,
|
194
|
+
tracking_library=tracking_library,
|
195
|
+
tracking_params=tracking_params,
|
194
196
|
experiment_name=experiment_name,
|
195
197
|
experiment_description=experiment_description,
|
196
|
-
experiment_tags=experiment_tags,
|
197
|
-
tracking_uri=tracking_uri,
|
198
198
|
dataset_type=dataset_type,
|
199
|
+
dataset_additional_config=additional_config,
|
199
200
|
umls_code_types = None,
|
200
201
|
text_types = None,
|
201
|
-
columns=
|
202
|
-
ner_data_file=ner_data_file,
|
202
|
+
columns=columns,
|
203
203
|
multi_modal= multi_modal
|
204
204
|
)
|
205
205
|
|
ddi_fw/pipeline/ner_pipeline.py
CHANGED
@@ -10,19 +10,21 @@ from ddi_fw.vectorization.idf_helper import IDF
|
|
10
10
|
from ddi_fw.ner.ner import CTakesNER
|
11
11
|
from ddi_fw.ml.ml_helper import MultiModalRunner
|
12
12
|
from ddi_fw.utils.enums import DrugBankTextDataTypes, UMLSCodeTypes
|
13
|
+
import logging
|
13
14
|
|
14
|
-
|
15
|
+
|
15
16
|
class NerParameterSearch(BaseModel):
|
16
17
|
library: str
|
17
18
|
default_model: Optional[Any] = None
|
18
19
|
multi_modal: Optional[Any] = None
|
19
20
|
experiment_name: str
|
20
21
|
experiment_description: Optional[str] = None
|
21
|
-
|
22
|
-
|
22
|
+
tracking_library: str
|
23
|
+
tracking_params: Optional[Dict[str, Any]] = None
|
24
|
+
dataset_type: Type[BaseDataset]
|
25
|
+
dataset_additional_config: Optional[Dict[str, Any]] = None
|
23
26
|
dataset_type: Type[BaseDataset]
|
24
27
|
dataset_splitter_type: Type[DatasetSplitter] = DatasetSplitter
|
25
|
-
ner_data_file: Optional[str] = None
|
26
28
|
columns: List[str] = Field(default_factory=list)
|
27
29
|
umls_code_types: Optional[List[UMLSCodeTypes]] = None
|
28
30
|
text_types: Optional[List[DrugBankTextDataTypes]] = None
|
@@ -33,7 +35,7 @@ class NerParameterSearch(BaseModel):
|
|
33
35
|
# Internal fields (not part of the input)
|
34
36
|
datasets: Dict[str, Any] = Field(default_factory=dict, exclude=True)
|
35
37
|
items: List[Any] = Field(default_factory=list, exclude=True)
|
36
|
-
ner_df: Optional[Any] = Field(default=None, exclude=True)
|
38
|
+
# ner_df: Optional[Any] = Field(default=None, exclude=True)
|
37
39
|
train_idx_arr: Optional[List[np.ndarray]] = Field(default=None, exclude=True)
|
38
40
|
val_idx_arr: Optional[List[np.ndarray]] = Field(default=None, exclude=True)
|
39
41
|
y_test_label: Optional[np.ndarray] = Field(default=None, exclude=True)
|
@@ -64,12 +66,17 @@ class NerParameterSearch(BaseModel):
|
|
64
66
|
raise TypeError("self.dataset_type must be a class, not an instance")
|
65
67
|
|
66
68
|
# Load NER data
|
67
|
-
|
68
|
-
|
69
|
+
ner_data_file = (
|
70
|
+
self.dataset_additional_config.get("ner", {}).get("data_file")
|
71
|
+
if self.dataset_additional_config else None
|
72
|
+
)
|
73
|
+
|
74
|
+
if ner_data_file:
|
75
|
+
ner_df = CTakesNER(df=None).load(filename=ner_data_file)
|
69
76
|
|
70
77
|
# Initialize thresholds if not provided
|
71
78
|
if not self.min_threshold_dict or not self.max_threshold_dict:
|
72
|
-
idf = IDF(
|
79
|
+
idf = IDF(ner_df, self.columns)
|
73
80
|
idf.calculate()
|
74
81
|
df = idf.to_dataframe()
|
75
82
|
self.min_threshold_dict = {key: np.floor(df.describe()[key]["min"]) for key in df.describe().keys()}
|
@@ -85,6 +92,8 @@ class NerParameterSearch(BaseModel):
|
|
85
92
|
"cui_threshold": 0,
|
86
93
|
"entities_threshold": 0,
|
87
94
|
}
|
95
|
+
if self.dataset_additional_config:
|
96
|
+
kwargs["additional_config"]= self.dataset_additional_config
|
88
97
|
|
89
98
|
for threshold in np.arange(min_threshold, max_threshold, self.increase_step):
|
90
99
|
if column.startswith("tui"):
|
@@ -93,10 +102,9 @@ class NerParameterSearch(BaseModel):
|
|
93
102
|
kwargs["cui_threshold"] = threshold
|
94
103
|
if column.startswith("entities"):
|
95
104
|
kwargs["entities_threshold"] = threshold
|
96
|
-
|
105
|
+
|
97
106
|
dataset = self.dataset_type(
|
98
107
|
columns=[column],
|
99
|
-
ner_df=self.ner_df,
|
100
108
|
dataset_splitter_type=self.dataset_splitter_type,
|
101
109
|
**kwargs,
|
102
110
|
)
|
@@ -113,22 +121,38 @@ class NerParameterSearch(BaseModel):
|
|
113
121
|
self.train_idx_arr = dataset.train_idx_arr
|
114
122
|
self.val_idx_arr = dataset.val_idx_arr
|
115
123
|
|
116
|
-
def run(self):
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
124
|
+
# def run(self):
|
125
|
+
# """Run the parameter search."""
|
126
|
+
# mlflow.set_tracking_uri(self.tracking_uri)
|
127
|
+
|
128
|
+
# if mlflow.get_experiment_by_name(self.experiment_name) is None:
|
129
|
+
# mlflow.create_experiment(self.experiment_name)
|
130
|
+
# if self.experiment_tags:
|
131
|
+
# mlflow.set_experiment_tags(self.experiment_tags)
|
132
|
+
# mlflow.set_experiment(self.experiment_name)
|
133
|
+
|
134
|
+
# multi_modal_runner = MultiModalRunner(
|
135
|
+
# library=self.library,
|
136
|
+
# multi_modal=self.multi_modal,
|
137
|
+
# default_model=self.default_model,
|
138
|
+
# use_mlflow=True,
|
139
|
+
# )
|
140
|
+
# multi_modal_runner.set_data(self.items, self.train_idx_arr, self.val_idx_arr, self.y_test_label)
|
141
|
+
# result = multi_modal_runner.predict()
|
142
|
+
# return result
|
143
|
+
|
144
|
+
def run(self):
|
145
|
+
if self._tracking_service is None:
|
146
|
+
logging.warning("Tracking service is not initialized.")
|
147
|
+
else:
|
148
|
+
self._tracking_service.setup()
|
149
|
+
|
150
|
+
y_test_label = self.items[0][4]
|
151
|
+
multi_modal_runner = MultiModalRunner(
|
152
|
+
library=self.library, multi_modal=self.multi_modal, default_model=self.default_model, tracking_service=self._tracking_service)
|
153
|
+
|
154
|
+
multi_modal_runner.set_data(
|
155
|
+
self.items, self.train_idx_arr, self.val_idx_arr, y_test_label)
|
156
|
+
combinations = self.combinations if self.combinations is not None else []
|
157
|
+
result = multi_modal_runner.predict(combinations)
|
158
|
+
return result
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ddi_fw
|
3
|
-
Version: 0.0.225
|
3
|
+
Version: 0.0.227
|
4
4
|
Summary: Do not use :)
|
5
5
|
Author-email: Kıvanç Bayraktar <bayraktarkivanc@gmail.com>
|
6
6
|
Maintainer-email: Kıvanç Bayraktar <bayraktarkivanc@gmail.com>
|
@@ -33,7 +33,7 @@ Requires-Dist: tokenizers>=0.19.1; extra == "llm"
|
|
33
33
|
Requires-Dist: openai>=1.52.2; extra == "llm"
|
34
34
|
Requires-Dist: langchain>=0.3.4; extra == "llm"
|
35
35
|
Requires-Dist: langchain_community>0.3.16; extra == "llm"
|
36
|
-
Requires-Dist: chromadb
|
36
|
+
Requires-Dist: chromadb<=1.0.4,>=0.7.0; extra == "llm"
|
37
37
|
Requires-Dist: nltk>=3.8.1; extra == "llm"
|
38
38
|
Provides-Extra: ml
|
39
39
|
Requires-Dist: scikit-learn<=1.6.1,>=1.5.2; extra == "ml"
|
@@ -1,5 +1,5 @@
|
|
1
1
|
ddi_fw/datasets/__init__.py,sha256=_I3iDHARwzmg7_EL5XKtB_TgG1yAkLSOVTujLL9Wz9Q,280
|
2
|
-
ddi_fw/datasets/core.py,sha256=
|
2
|
+
ddi_fw/datasets/core.py,sha256=p-e3wP5C_SCh0fMXioUHUXKvLVtyCrsQCFvKRnH4fjs,17008
|
3
3
|
ddi_fw/datasets/dataset_splitter.py,sha256=8H8uZTAf8N9LUZeSeHOMawtJFJhnDgUUqFcnl7dquBQ,1672
|
4
4
|
ddi_fw/datasets/db_utils.py,sha256=xRj28U_uXTRPHcz3yIICczFUHXUPiAOZtAj5BM6kH44,6465
|
5
5
|
ddi_fw/datasets/setup_._py,sha256=khYVJuW5PlOY_i_A16F3UbSZ6s6o_ljw33Byw3C-A8E,1047
|
@@ -77,16 +77,16 @@ ddi_fw/ml/evaluation_helper.py,sha256=2-7CLSgGTqLEk4HkgCVIOt-GxfLAn6SBozJghAtHb5
|
|
77
77
|
ddi_fw/ml/ml_helper.py,sha256=EXMmaSoSmP4RR1zyb1crBE8wwfJohHwWvOhelddtMhI,7945
|
78
78
|
ddi_fw/ml/model_wrapper.py,sha256=38uBdHI4H_sjDKPWuhGXovUy_L1tpSNm5tEqCtwmlpY,973
|
79
79
|
ddi_fw/ml/pytorch_wrapper.py,sha256=pe6UsjP2XeTgLxDnIUiodoyhJTGCxV27wD4Cjxysu2Q,8553
|
80
|
-
ddi_fw/ml/tensorflow_wrapper.py,sha256=
|
80
|
+
ddi_fw/ml/tensorflow_wrapper.py,sha256=_mOXMpIkXx7lJySC2wtCDIDhSdtA8bQVEjKwJ5NQ7Io,16782
|
81
81
|
ddi_fw/ml/tracking_service.py,sha256=eHWFI3lyQX_xM16CRekgITwldHj2RBMYl5XG8lD8Zks,7508
|
82
82
|
ddi_fw/ner/__init__.py,sha256=JwhGXrepomxPSsGsg2b_xPRC72AjvxOIn2CW5Mvscn0,26
|
83
83
|
ddi_fw/ner/mmlrestclient.py,sha256=NZta7m2Qm6I_qtVguMZhqtAUjVBmmXn0-TMnsNp0jpg,6859
|
84
84
|
ddi_fw/ner/ner.py,sha256=FHyyX53Xwpdw8Hec261dyN88yD7Z9LmJua2mIrQLguI,17967
|
85
85
|
ddi_fw/pipeline/__init__.py,sha256=tKDM_rW4vPjlYTeOkNgi9PujDzb4e9O3LK1w5wqnebw,212
|
86
86
|
ddi_fw/pipeline/multi_modal_combination_strategy.py,sha256=JSyuP71b1I1yuk0s2ecCJZTtCED85jBtkpwTUxibJvI,1706
|
87
|
-
ddi_fw/pipeline/multi_pipeline.py,sha256=
|
87
|
+
ddi_fw/pipeline/multi_pipeline.py,sha256=EjJnA3Vzd-WeEvUBaA2LDOy_iQ5-2eW2VhtxvvxDPfQ,9857
|
88
88
|
ddi_fw/pipeline/multi_pipeline_org.py,sha256=AbErwu05-3YIPnCcXRsj-jxPJG8HG2H7cMZlGjzaYa8,9037
|
89
|
-
ddi_fw/pipeline/ner_pipeline.py,sha256=
|
89
|
+
ddi_fw/pipeline/ner_pipeline.py,sha256=IVtmlBhQ73FeR0b26U33yWlNVwqiEqdvBAseTz6CVsk,6954
|
90
90
|
ddi_fw/pipeline/pipeline.py,sha256=q1kMkW9-fOlrA4BOGUku40U_PuEYfcbtH2EvlRM4uTM,6243
|
91
91
|
ddi_fw/utils/__init__.py,sha256=WNxkQXk-694roG50D355TGLXstfdWVb_tUyr-PM-8rg,537
|
92
92
|
ddi_fw/utils/categorical_data_encoding_checker.py,sha256=T1X70Rh4atucAuqyUZmz-iFULllY9dY0NRyV9-jTjJ0,3438
|
@@ -101,7 +101,7 @@ ddi_fw/utils/zip_helper.py,sha256=YRZA4tKZVBJwGQM0_WK6L-y5MoqkKoC-nXuuHK6CU9I,55
|
|
101
101
|
ddi_fw/vectorization/__init__.py,sha256=LcJOpLVoLvHPDw9phGFlUQGeNcST_zKV-Oi1Pm5h_nE,110
|
102
102
|
ddi_fw/vectorization/feature_vector_generation.py,sha256=EBf-XAiwQwr68az91erEYNegfeqssBR29kVgrliIyac,4765
|
103
103
|
ddi_fw/vectorization/idf_helper.py,sha256=_Gd1dtDSLaw8o-o0JugzSKMt9FpeXewTh4wGEaUd4VQ,2571
|
104
|
-
ddi_fw-0.0.
|
105
|
-
ddi_fw-0.0.
|
106
|
-
ddi_fw-0.0.
|
107
|
-
ddi_fw-0.0.
|
104
|
+
ddi_fw-0.0.227.dist-info/METADATA,sha256=yVVPcTBE4VRLFs4K7jWuOQWoLe-B_i8c8BV1YJCjI7U,2632
|
105
|
+
ddi_fw-0.0.227.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
|
106
|
+
ddi_fw-0.0.227.dist-info/top_level.txt,sha256=PMwHICFZTZtcpzQNPV4UQnfNXYIeLR_Ste-Wfc1h810,7
|
107
|
+
ddi_fw-0.0.227.dist-info/RECORD,,
|
File without changes
|