ddi-fw 0.0.231__py3-none-any.whl → 0.0.232__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddi_fw/pipeline/ner_pipeline.py +38 -19
- {ddi_fw-0.0.231.dist-info → ddi_fw-0.0.232.dist-info}/METADATA +1 -1
- {ddi_fw-0.0.231.dist-info → ddi_fw-0.0.232.dist-info}/RECORD +5 -5
- {ddi_fw-0.0.231.dist-info → ddi_fw-0.0.232.dist-info}/WHEEL +0 -0
- {ddi_fw-0.0.231.dist-info → ddi_fw-0.0.232.dist-info}/top_level.txt +0 -0
ddi_fw/pipeline/ner_pipeline.py
CHANGED
@@ -6,6 +6,7 @@ import mlflow
|
|
6
6
|
from pydantic import BaseModel, Field, model_validator, root_validator, validator
|
7
7
|
from ddi_fw.datasets.core import BaseDataset
|
8
8
|
from ddi_fw.datasets.dataset_splitter import DatasetSplitter
|
9
|
+
from ddi_fw.ml.tracking_service import TrackingService
|
9
10
|
from ddi_fw.vectorization.idf_helper import IDF
|
10
11
|
from ddi_fw.ner.ner import CTakesNER
|
11
12
|
from ddi_fw.ml.ml_helper import MultiModalRunner
|
@@ -33,6 +34,7 @@ class NerParameterSearch(BaseModel):
|
|
33
34
|
increase_step: float = 0.5
|
34
35
|
|
35
36
|
# Internal fields (not part of the input)
|
37
|
+
_tracking_service: TrackingService | None = None
|
36
38
|
datasets: Dict[str, Any] = Field(default_factory=dict, exclude=True)
|
37
39
|
items: List[Any] = Field(default_factory=list, exclude=True)
|
38
40
|
# ner_df: Optional[Any] = Field(default=None, exclude=True)
|
@@ -42,6 +44,10 @@ class NerParameterSearch(BaseModel):
|
|
42
44
|
|
43
45
|
class Config:
|
44
46
|
arbitrary_types_allowed = True
|
47
|
+
|
48
|
+
@property
def tracking_service(self) -> TrackingService | None:
    """Return the tracking service held in ``_tracking_service``.

    The value is ``None`` until ``build()`` assigns a ``TrackingService``
    instance to the private attribute.
    """
    return self._tracking_service
|
45
51
|
|
46
52
|
# @root_validator(pre=True)
|
47
53
|
@model_validator(mode="before")
|
@@ -61,6 +67,9 @@ class NerParameterSearch(BaseModel):
|
|
61
67
|
return values
|
62
68
|
|
63
69
|
def build(self):
|
70
|
+
self._tracking_service = TrackingService(self.experiment_name,
|
71
|
+
backend=self.tracking_library, tracking_params=self.tracking_params)
|
72
|
+
|
64
73
|
"""Build the datasets and items for the parameter search."""
|
65
74
|
if not isinstance(self.dataset_type, type):
|
66
75
|
raise TypeError("self.dataset_type must be a class, not an instance")
|
@@ -124,10 +133,20 @@ class NerParameterSearch(BaseModel):
|
|
124
133
|
self.datasets[item[0]] = dataset
|
125
134
|
|
126
135
|
self.items.extend(group_items)
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
136
|
+
|
137
|
+
# Set if y_test_label is None
|
138
|
+
# This ensures that y_test_label is set only once for the first dataset
|
139
|
+
if self.y_test_label is None:
|
140
|
+
self.y_test_label = self.items[0][4]
|
141
|
+
self.train_idx_arr = dataset.train_idx_arr
|
142
|
+
self.val_idx_arr = dataset.val_idx_arr
|
143
|
+
|
144
|
+
# Clear memory for the current dataset and items
|
145
|
+
del dataset
|
146
|
+
del group_items
|
147
|
+
import gc
|
148
|
+
gc.collect()
|
149
|
+
|
131
150
|
|
132
151
|
# def run(self):
|
133
152
|
# """Run the parameter search."""
|
@@ -149,18 +168,18 @@ class NerParameterSearch(BaseModel):
|
|
149
168
|
# result = multi_modal_runner.predict()
|
150
169
|
# return result
|
151
170
|
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
171
|
+
def run(self):
    """Run the multi-modal prediction over the items prepared by ``build()``.

    Sets up the tracking service when one has been created, passes
    ``self.items`` together with the train/validation index arrays to a
    ``MultiModalRunner``, and returns the value of its ``predict()`` call.

    Returns:
        Whatever ``MultiModalRunner.predict()`` returns.

    Raises:
        IndexError: If ``self.items`` is empty, i.e. nothing has been
            prepared to run on (``build()`` is what fills ``items``).
    """
    # Fail fast, before any tracking side effects, and name the cause
    # instead of surfacing a bare "list index out of range".
    if not self.items:
        raise IndexError(
            "No items available to run; call build() before run().")

    if self._tracking_service is None:
        # Tracking is optional: warn and continue without it.
        logging.warning("Tracking service is not initialized.")
    else:
        self._tracking_service.setup()

    # Slot 4 of the first item is used as the test labels; build() reads
    # the same slot when it fills ``y_test_label``.
    y_test_label = self.items[0][4]

    multi_modal_runner = MultiModalRunner(
        library=self.library,
        multi_modal=self.multi_modal,
        default_model=self.default_model,
        tracking_service=self._tracking_service,
    )
    multi_modal_runner.set_data(
        self.items, self.train_idx_arr, self.val_idx_arr, y_test_label)

    return multi_modal_runner.predict()
|
@@ -86,7 +86,7 @@ ddi_fw/pipeline/__init__.py,sha256=tKDM_rW4vPjlYTeOkNgi9PujDzb4e9O3LK1w5wqnebw,2
|
|
86
86
|
ddi_fw/pipeline/multi_modal_combination_strategy.py,sha256=JSyuP71b1I1yuk0s2ecCJZTtCED85jBtkpwTUxibJvI,1706
|
87
87
|
ddi_fw/pipeline/multi_pipeline.py,sha256=EjJnA3Vzd-WeEvUBaA2LDOy_iQ5-2eW2VhtxvvxDPfQ,9857
|
88
88
|
ddi_fw/pipeline/multi_pipeline_org.py,sha256=AbErwu05-3YIPnCcXRsj-jxPJG8HG2H7cMZlGjzaYa8,9037
|
89
|
-
ddi_fw/pipeline/ner_pipeline.py,sha256=
|
89
|
+
ddi_fw/pipeline/ner_pipeline.py,sha256=BycxZvI7JRJ3s3HhYAgOxG2_lqrVnhv7ECOWSgVQhz4,8186
|
90
90
|
ddi_fw/pipeline/pipeline.py,sha256=q1kMkW9-fOlrA4BOGUku40U_PuEYfcbtH2EvlRM4uTM,6243
|
91
91
|
ddi_fw/utils/__init__.py,sha256=WNxkQXk-694roG50D355TGLXstfdWVb_tUyr-PM-8rg,537
|
92
92
|
ddi_fw/utils/categorical_data_encoding_checker.py,sha256=T1X70Rh4atucAuqyUZmz-iFULllY9dY0NRyV9-jTjJ0,3438
|
@@ -101,7 +101,7 @@ ddi_fw/utils/zip_helper.py,sha256=YRZA4tKZVBJwGQM0_WK6L-y5MoqkKoC-nXuuHK6CU9I,55
|
|
101
101
|
ddi_fw/vectorization/__init__.py,sha256=LcJOpLVoLvHPDw9phGFlUQGeNcST_zKV-Oi1Pm5h_nE,110
|
102
102
|
ddi_fw/vectorization/feature_vector_generation.py,sha256=93G3QM28uoNlvlVz_BhV6ARxldpogiNJStxHdsgqTbU,6026
|
103
103
|
ddi_fw/vectorization/idf_helper.py,sha256=_Gd1dtDSLaw8o-o0JugzSKMt9FpeXewTh4wGEaUd4VQ,2571
|
104
|
-
ddi_fw-0.0.
|
105
|
-
ddi_fw-0.0.
|
106
|
-
ddi_fw-0.0.
|
107
|
-
ddi_fw-0.0.
|
104
|
+
ddi_fw-0.0.232.dist-info/METADATA,sha256=CBSE9xsWEc0vxlCxw9NCLxTJWXmXoFTegN5wgrHVGvA,2632
|
105
|
+
ddi_fw-0.0.232.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
|
106
|
+
ddi_fw-0.0.232.dist-info/top_level.txt,sha256=PMwHICFZTZtcpzQNPV4UQnfNXYIeLR_Ste-Wfc1h810,7
|
107
|
+
ddi_fw-0.0.232.dist-info/RECORD,,
|
File without changes
|
File without changes
|