ddi-fw 0.0.225__py3-none-any.whl → 0.0.227__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
ddi_fw/datasets/core.py CHANGED
@@ -4,7 +4,7 @@ import glob
4
4
  import logging
5
5
  from typing import Any, Dict, List, Optional, Type
6
6
  import chromadb
7
- from chromadb.api.types import IncludeEnum
7
+ # from chromadb.api.types import IncludeEnum
8
8
  import numpy as np
9
9
  import pandas as pd
10
10
  from pydantic import BaseModel, Field, computed_field
@@ -307,7 +307,8 @@ class TextDatasetMixin(BaseModel):
307
307
  vector_db = chromadb.PersistentClient(
308
308
  path=vector_db_persist_directory)
309
309
  collection = vector_db.get_collection(vector_db_collection_name)
310
- include = [IncludeEnum.embeddings, IncludeEnum.metadatas]
310
+ # include = [IncludeEnum.embeddings, IncludeEnum.metadatas]
311
+ include: chromadb.Include = ["embeddings","metadatas"]
311
312
  dictionary: chromadb.GetResult
312
313
  # Fetch the embeddings and metadata
313
314
  if column == None:
@@ -102,9 +102,18 @@ class TFModelWrapper(ModelWrapper):
102
102
  history = model.fit(
103
103
  train_dataset,
104
104
  epochs=self.epochs,
105
- # validation_data=val_dataset,
105
+ validation_data=val_dataset,
106
106
  callbacks=callbacks
107
107
  )
108
+
109
+ # Check if early stopping was applied
110
+ if early_stopping.stopped_epoch > 0:
111
+ print(f"Early stopping was applied at epoch {early_stopping.stopped_epoch}.")
112
+ else:
113
+ print("Early stopping was not applied.")
114
+ if self.tracking_service:
115
+ self.tracking_service.log_param("early_stopping_applied", early_stopping.stopped_epoch > 0)
116
+ self.tracking_service.log_param("early_stopping_epoch", early_stopping.stopped_epoch)
108
117
  # ex
109
118
  # history = model.fit(
110
119
  # X_train, y_train,
@@ -174,6 +183,7 @@ class TFModelWrapper(ModelWrapper):
174
183
  if models_val_acc == {}:
175
184
  return model, None
176
185
  best_model_key = max(models_val_acc, key=lambda k: models_val_acc[k])
186
+ print("best model key: ", best_model_key)
177
187
  # best_model_key = max(models_val_acc, key=models_val_acc.get)
178
188
  best_model = models[best_model_key]
179
189
  return best_model, best_model_key
@@ -191,15 +191,15 @@ class MultiPipeline():
191
191
  elif type== "ner_search":
192
192
  pipeline = NerParameterSearch(
193
193
  library=library,
194
+ tracking_library=tracking_library,
195
+ tracking_params=tracking_params,
194
196
  experiment_name=experiment_name,
195
197
  experiment_description=experiment_description,
196
- experiment_tags=experiment_tags,
197
- tracking_uri=tracking_uri,
198
198
  dataset_type=dataset_type,
199
+ dataset_additional_config=additional_config,
199
200
  umls_code_types = None,
200
201
  text_types = None,
201
- columns=['tui', 'cui', 'entities'],
202
- ner_data_file=ner_data_file,
202
+ columns=columns,
203
203
  multi_modal= multi_modal
204
204
  )
205
205
 
@@ -10,19 +10,21 @@ from ddi_fw.vectorization.idf_helper import IDF
10
10
  from ddi_fw.ner.ner import CTakesNER
11
11
  from ddi_fw.ml.ml_helper import MultiModalRunner
12
12
  from ddi_fw.utils.enums import DrugBankTextDataTypes, UMLSCodeTypes
13
+ import logging
13
14
 
14
-
15
+
15
16
  class NerParameterSearch(BaseModel):
16
17
  library: str
17
18
  default_model: Optional[Any] = None
18
19
  multi_modal: Optional[Any] = None
19
20
  experiment_name: str
20
21
  experiment_description: Optional[str] = None
21
- experiment_tags: Optional[Dict[str, Any]] = None
22
- tracking_uri: str
22
+ tracking_library: str
23
+ tracking_params: Optional[Dict[str, Any]] = None
24
+ dataset_type: Type[BaseDataset]
25
+ dataset_additional_config: Optional[Dict[str, Any]] = None
23
26
  dataset_type: Type[BaseDataset]
24
27
  dataset_splitter_type: Type[DatasetSplitter] = DatasetSplitter
25
- ner_data_file: Optional[str] = None
26
28
  columns: List[str] = Field(default_factory=list)
27
29
  umls_code_types: Optional[List[UMLSCodeTypes]] = None
28
30
  text_types: Optional[List[DrugBankTextDataTypes]] = None
@@ -33,7 +35,7 @@ class NerParameterSearch(BaseModel):
33
35
  # Internal fields (not part of the input)
34
36
  datasets: Dict[str, Any] = Field(default_factory=dict, exclude=True)
35
37
  items: List[Any] = Field(default_factory=list, exclude=True)
36
- ner_df: Optional[Any] = Field(default=None, exclude=True)
38
+ # ner_df: Optional[Any] = Field(default=None, exclude=True)
37
39
  train_idx_arr: Optional[List[np.ndarray]] = Field(default=None, exclude=True)
38
40
  val_idx_arr: Optional[List[np.ndarray]] = Field(default=None, exclude=True)
39
41
  y_test_label: Optional[np.ndarray] = Field(default=None, exclude=True)
@@ -64,12 +66,17 @@ class NerParameterSearch(BaseModel):
64
66
  raise TypeError("self.dataset_type must be a class, not an instance")
65
67
 
66
68
  # Load NER data
67
- if self.ner_data_file:
68
- self.ner_df = CTakesNER(df=None).load(filename=self.ner_data_file)
69
+ ner_data_file = (
70
+ self.dataset_additional_config.get("ner", {}).get("data_file")
71
+ if self.dataset_additional_config else None
72
+ )
73
+
74
+ if ner_data_file:
75
+ ner_df = CTakesNER(df=None).load(filename=ner_data_file)
69
76
 
70
77
  # Initialize thresholds if not provided
71
78
  if not self.min_threshold_dict or not self.max_threshold_dict:
72
- idf = IDF(self.ner_df, self.columns)
79
+ idf = IDF(ner_df, self.columns)
73
80
  idf.calculate()
74
81
  df = idf.to_dataframe()
75
82
  self.min_threshold_dict = {key: np.floor(df.describe()[key]["min"]) for key in df.describe().keys()}
@@ -85,6 +92,8 @@ class NerParameterSearch(BaseModel):
85
92
  "cui_threshold": 0,
86
93
  "entities_threshold": 0,
87
94
  }
95
+ if self.dataset_additional_config:
96
+ kwargs["additional_config"]= self.dataset_additional_config
88
97
 
89
98
  for threshold in np.arange(min_threshold, max_threshold, self.increase_step):
90
99
  if column.startswith("tui"):
@@ -93,10 +102,9 @@ class NerParameterSearch(BaseModel):
93
102
  kwargs["cui_threshold"] = threshold
94
103
  if column.startswith("entities"):
95
104
  kwargs["entities_threshold"] = threshold
96
-
105
+
97
106
  dataset = self.dataset_type(
98
107
  columns=[column],
99
- ner_df=self.ner_df,
100
108
  dataset_splitter_type=self.dataset_splitter_type,
101
109
  **kwargs,
102
110
  )
@@ -113,22 +121,38 @@ class NerParameterSearch(BaseModel):
113
121
  self.train_idx_arr = dataset.train_idx_arr
114
122
  self.val_idx_arr = dataset.val_idx_arr
115
123
 
116
- def run(self):
117
- """Run the parameter search."""
118
- mlflow.set_tracking_uri(self.tracking_uri)
119
-
120
- if mlflow.get_experiment_by_name(self.experiment_name) is None:
121
- mlflow.create_experiment(self.experiment_name)
122
- if self.experiment_tags:
123
- mlflow.set_experiment_tags(self.experiment_tags)
124
- mlflow.set_experiment(self.experiment_name)
125
-
126
- multi_modal_runner = MultiModalRunner(
127
- library=self.library,
128
- multi_modal=self.multi_modal,
129
- default_model=self.default_model,
130
- use_mlflow=True,
131
- )
132
- multi_modal_runner.set_data(self.items, self.train_idx_arr, self.val_idx_arr, self.y_test_label)
133
- result = multi_modal_runner.predict()
134
- return result
124
+ # def run(self):
125
+ # """Run the parameter search."""
126
+ # mlflow.set_tracking_uri(self.tracking_uri)
127
+
128
+ # if mlflow.get_experiment_by_name(self.experiment_name) is None:
129
+ # mlflow.create_experiment(self.experiment_name)
130
+ # if self.experiment_tags:
131
+ # mlflow.set_experiment_tags(self.experiment_tags)
132
+ # mlflow.set_experiment(self.experiment_name)
133
+
134
+ # multi_modal_runner = MultiModalRunner(
135
+ # library=self.library,
136
+ # multi_modal=self.multi_modal,
137
+ # default_model=self.default_model,
138
+ # use_mlflow=True,
139
+ # )
140
+ # multi_modal_runner.set_data(self.items, self.train_idx_arr, self.val_idx_arr, self.y_test_label)
141
+ # result = multi_modal_runner.predict()
142
+ # return result
143
+
144
+ def run(self):
145
+ if self._tracking_service is None:
146
+ logging.warning("Tracking service is not initialized.")
147
+ else:
148
+ self._tracking_service.setup()
149
+
150
+ y_test_label = self.items[0][4]
151
+ multi_modal_runner = MultiModalRunner(
152
+ library=self.library, multi_modal=self.multi_modal, default_model=self.default_model, tracking_service=self._tracking_service)
153
+
154
+ multi_modal_runner.set_data(
155
+ self.items, self.train_idx_arr, self.val_idx_arr, y_test_label)
156
+ combinations = self.combinations if self.combinations is not None else []
157
+ result = multi_modal_runner.predict(combinations)
158
+ return result
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ddi_fw
3
- Version: 0.0.225
3
+ Version: 0.0.227
4
4
  Summary: Do not use :)
5
5
  Author-email: Kıvanç Bayraktar <bayraktarkivanc@gmail.com>
6
6
  Maintainer-email: Kıvanç Bayraktar <bayraktarkivanc@gmail.com>
@@ -33,7 +33,7 @@ Requires-Dist: tokenizers>=0.19.1; extra == "llm"
33
33
  Requires-Dist: openai>=1.52.2; extra == "llm"
34
34
  Requires-Dist: langchain>=0.3.4; extra == "llm"
35
35
  Requires-Dist: langchain_community>0.3.16; extra == "llm"
36
- Requires-Dist: chromadb<1.0.0,>=0.6.0; extra == "llm"
36
+ Requires-Dist: chromadb<=1.0.4,>=0.7.0; extra == "llm"
37
37
  Requires-Dist: nltk>=3.8.1; extra == "llm"
38
38
  Provides-Extra: ml
39
39
  Requires-Dist: scikit-learn<=1.6.1,>=1.5.2; extra == "ml"
@@ -1,5 +1,5 @@
1
1
  ddi_fw/datasets/__init__.py,sha256=_I3iDHARwzmg7_EL5XKtB_TgG1yAkLSOVTujLL9Wz9Q,280
2
- ddi_fw/datasets/core.py,sha256=PX6MX4hmeYxIWAKAx7NnJr1fpzR11xA8g8vAjYcQNN8,16936
2
+ ddi_fw/datasets/core.py,sha256=p-e3wP5C_SCh0fMXioUHUXKvLVtyCrsQCFvKRnH4fjs,17008
3
3
  ddi_fw/datasets/dataset_splitter.py,sha256=8H8uZTAf8N9LUZeSeHOMawtJFJhnDgUUqFcnl7dquBQ,1672
4
4
  ddi_fw/datasets/db_utils.py,sha256=xRj28U_uXTRPHcz3yIICczFUHXUPiAOZtAj5BM6kH44,6465
5
5
  ddi_fw/datasets/setup_._py,sha256=khYVJuW5PlOY_i_A16F3UbSZ6s6o_ljw33Byw3C-A8E,1047
@@ -77,16 +77,16 @@ ddi_fw/ml/evaluation_helper.py,sha256=2-7CLSgGTqLEk4HkgCVIOt-GxfLAn6SBozJghAtHb5
77
77
  ddi_fw/ml/ml_helper.py,sha256=EXMmaSoSmP4RR1zyb1crBE8wwfJohHwWvOhelddtMhI,7945
78
78
  ddi_fw/ml/model_wrapper.py,sha256=38uBdHI4H_sjDKPWuhGXovUy_L1tpSNm5tEqCtwmlpY,973
79
79
  ddi_fw/ml/pytorch_wrapper.py,sha256=pe6UsjP2XeTgLxDnIUiodoyhJTGCxV27wD4Cjxysu2Q,8553
80
- ddi_fw/ml/tensorflow_wrapper.py,sha256=8hQitM6r0jVkSi4P5O4qjGYuJFT326JcojCrifVEF_M,16227
80
+ ddi_fw/ml/tensorflow_wrapper.py,sha256=_mOXMpIkXx7lJySC2wtCDIDhSdtA8bQVEjKwJ5NQ7Io,16782
81
81
  ddi_fw/ml/tracking_service.py,sha256=eHWFI3lyQX_xM16CRekgITwldHj2RBMYl5XG8lD8Zks,7508
82
82
  ddi_fw/ner/__init__.py,sha256=JwhGXrepomxPSsGsg2b_xPRC72AjvxOIn2CW5Mvscn0,26
83
83
  ddi_fw/ner/mmlrestclient.py,sha256=NZta7m2Qm6I_qtVguMZhqtAUjVBmmXn0-TMnsNp0jpg,6859
84
84
  ddi_fw/ner/ner.py,sha256=FHyyX53Xwpdw8Hec261dyN88yD7Z9LmJua2mIrQLguI,17967
85
85
  ddi_fw/pipeline/__init__.py,sha256=tKDM_rW4vPjlYTeOkNgi9PujDzb4e9O3LK1w5wqnebw,212
86
86
  ddi_fw/pipeline/multi_modal_combination_strategy.py,sha256=JSyuP71b1I1yuk0s2ecCJZTtCED85jBtkpwTUxibJvI,1706
87
- ddi_fw/pipeline/multi_pipeline.py,sha256=npJUXYT31fxD6kpJKSeixjbH5jNfPUwIVG7lRdBszRg,9852
87
+ ddi_fw/pipeline/multi_pipeline.py,sha256=EjJnA3Vzd-WeEvUBaA2LDOy_iQ5-2eW2VhtxvvxDPfQ,9857
88
88
  ddi_fw/pipeline/multi_pipeline_org.py,sha256=AbErwu05-3YIPnCcXRsj-jxPJG8HG2H7cMZlGjzaYa8,9037
89
- ddi_fw/pipeline/ner_pipeline.py,sha256=yp-Met2794EKcgr8_3gqt03l4v2efOdaZuAcIXTubvQ,5780
89
+ ddi_fw/pipeline/ner_pipeline.py,sha256=IVtmlBhQ73FeR0b26U33yWlNVwqiEqdvBAseTz6CVsk,6954
90
90
  ddi_fw/pipeline/pipeline.py,sha256=q1kMkW9-fOlrA4BOGUku40U_PuEYfcbtH2EvlRM4uTM,6243
91
91
  ddi_fw/utils/__init__.py,sha256=WNxkQXk-694roG50D355TGLXstfdWVb_tUyr-PM-8rg,537
92
92
  ddi_fw/utils/categorical_data_encoding_checker.py,sha256=T1X70Rh4atucAuqyUZmz-iFULllY9dY0NRyV9-jTjJ0,3438
@@ -101,7 +101,7 @@ ddi_fw/utils/zip_helper.py,sha256=YRZA4tKZVBJwGQM0_WK6L-y5MoqkKoC-nXuuHK6CU9I,55
101
101
  ddi_fw/vectorization/__init__.py,sha256=LcJOpLVoLvHPDw9phGFlUQGeNcST_zKV-Oi1Pm5h_nE,110
102
102
  ddi_fw/vectorization/feature_vector_generation.py,sha256=EBf-XAiwQwr68az91erEYNegfeqssBR29kVgrliIyac,4765
103
103
  ddi_fw/vectorization/idf_helper.py,sha256=_Gd1dtDSLaw8o-o0JugzSKMt9FpeXewTh4wGEaUd4VQ,2571
104
- ddi_fw-0.0.225.dist-info/METADATA,sha256=Oco3hzLa5jxJN2MLpaWtqzDeYALwV5g5pR-zK_GU4aE,2631
105
- ddi_fw-0.0.225.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
106
- ddi_fw-0.0.225.dist-info/top_level.txt,sha256=PMwHICFZTZtcpzQNPV4UQnfNXYIeLR_Ste-Wfc1h810,7
107
- ddi_fw-0.0.225.dist-info/RECORD,,
104
+ ddi_fw-0.0.227.dist-info/METADATA,sha256=yVVPcTBE4VRLFs4K7jWuOQWoLe-B_i8c8BV1YJCjI7U,2632
105
+ ddi_fw-0.0.227.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
106
+ ddi_fw-0.0.227.dist-info/top_level.txt,sha256=PMwHICFZTZtcpzQNPV4UQnfNXYIeLR_Ste-Wfc1h810,7
107
+ ddi_fw-0.0.227.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.1.0)
2
+ Generator: setuptools (78.1.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5