ddi-fw 0.0.231__py3-none-any.whl → 0.0.232__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,6 +6,7 @@ import mlflow
6
6
  from pydantic import BaseModel, Field, model_validator, root_validator, validator
7
7
  from ddi_fw.datasets.core import BaseDataset
8
8
  from ddi_fw.datasets.dataset_splitter import DatasetSplitter
9
+ from ddi_fw.ml.tracking_service import TrackingService
9
10
  from ddi_fw.vectorization.idf_helper import IDF
10
11
  from ddi_fw.ner.ner import CTakesNER
11
12
  from ddi_fw.ml.ml_helper import MultiModalRunner
@@ -33,6 +34,7 @@ class NerParameterSearch(BaseModel):
33
34
  increase_step: float = 0.5
34
35
 
35
36
  # Internal fields (not part of the input)
37
+ _tracking_service: TrackingService | None = None
36
38
  datasets: Dict[str, Any] = Field(default_factory=dict, exclude=True)
37
39
  items: List[Any] = Field(default_factory=list, exclude=True)
38
40
  # ner_df: Optional[Any] = Field(default=None, exclude=True)
@@ -42,6 +44,10 @@ class NerParameterSearch(BaseModel):
42
44
 
43
45
  class Config:
44
46
  arbitrary_types_allowed = True
47
+
48
+ @property
49
+ def tracking_service(self) -> TrackingService | None:
50
+ return self._tracking_service
45
51
 
46
52
  # @root_validator(pre=True)
47
53
  @model_validator(mode="before")
@@ -61,6 +67,9 @@ class NerParameterSearch(BaseModel):
61
67
  return values
62
68
 
63
69
  def build(self):
70
+ self._tracking_service = TrackingService(self.experiment_name,
71
+ backend=self.tracking_library, tracking_params=self.tracking_params)
72
+
64
73
  """Build the datasets and items for the parameter search."""
65
74
  if not isinstance(self.dataset_type, type):
66
75
  raise TypeError("self.dataset_type must be a class, not an instance")
@@ -124,10 +133,20 @@ class NerParameterSearch(BaseModel):
124
133
  self.datasets[item[0]] = dataset
125
134
 
126
135
  self.items.extend(group_items)
127
-
128
- self.y_test_label = self.items[0][4]
129
- self.train_idx_arr = dataset.train_idx_arr
130
- self.val_idx_arr = dataset.val_idx_arr
136
+
137
+ # Set if y_test_label is None
138
+ # This ensures that y_test_label is set only once for the first dataset
139
+ if self.y_test_label is None:
140
+ self.y_test_label = self.items[0][4]
141
+ self.train_idx_arr = dataset.train_idx_arr
142
+ self.val_idx_arr = dataset.val_idx_arr
143
+
144
+ # Clear memory for the current dataset and items
145
+ del dataset
146
+ del group_items
147
+ import gc
148
+ gc.collect()
149
+
131
150
 
132
151
  # def run(self):
133
152
  # """Run the parameter search."""
@@ -149,18 +168,18 @@ class NerParameterSearch(BaseModel):
149
168
  # result = multi_modal_runner.predict()
150
169
  # return result
151
170
 
152
- def run(self):
153
- if self._tracking_service is None:
154
- logging.warning("Tracking service is not initialized.")
155
- else:
156
- self._tracking_service.setup()
157
-
158
- y_test_label = self.items[0][4]
159
- multi_modal_runner = MultiModalRunner(
160
- library=self.library, multi_modal=self.multi_modal, default_model=self.default_model, tracking_service=self._tracking_service)
161
-
162
- multi_modal_runner.set_data(
163
- self.items, self.train_idx_arr, self.val_idx_arr, y_test_label)
164
- combinations = self.combinations if self.combinations is not None else []
165
- result = multi_modal_runner.predict(combinations)
166
- return result
171
+ def run(self):
172
+ if self._tracking_service is None:
173
+ logging.warning("Tracking service is not initialized.")
174
+ else:
175
+ self._tracking_service.setup()
176
+
177
+ y_test_label = self.items[0][4]
178
+ multi_modal_runner = MultiModalRunner(
179
+ library=self.library, multi_modal=self.multi_modal, default_model=self.default_model, tracking_service=self._tracking_service)
180
+
181
+ multi_modal_runner.set_data(
182
+ self.items, self.train_idx_arr, self.val_idx_arr, y_test_label)
183
+ # combinations = self.combinations if self.combinations is not None else []
184
+ result = multi_modal_runner.predict()
185
+ return result
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ddi_fw
3
- Version: 0.0.231
3
+ Version: 0.0.232
4
4
  Summary: Do not use :)
5
5
  Author-email: Kıvanç Bayraktar <bayraktarkivanc@gmail.com>
6
6
  Maintainer-email: Kıvanç Bayraktar <bayraktarkivanc@gmail.com>
@@ -86,7 +86,7 @@ ddi_fw/pipeline/__init__.py,sha256=tKDM_rW4vPjlYTeOkNgi9PujDzb4e9O3LK1w5wqnebw,2
86
86
  ddi_fw/pipeline/multi_modal_combination_strategy.py,sha256=JSyuP71b1I1yuk0s2ecCJZTtCED85jBtkpwTUxibJvI,1706
87
87
  ddi_fw/pipeline/multi_pipeline.py,sha256=EjJnA3Vzd-WeEvUBaA2LDOy_iQ5-2eW2VhtxvvxDPfQ,9857
88
88
  ddi_fw/pipeline/multi_pipeline_org.py,sha256=AbErwu05-3YIPnCcXRsj-jxPJG8HG2H7cMZlGjzaYa8,9037
89
- ddi_fw/pipeline/ner_pipeline.py,sha256=EhM9A4AarpLASDMf4TDUCYFOESB-V3jEj77KskTqjXw,7368
89
+ ddi_fw/pipeline/ner_pipeline.py,sha256=BycxZvI7JRJ3s3HhYAgOxG2_lqrVnhv7ECOWSgVQhz4,8186
90
90
  ddi_fw/pipeline/pipeline.py,sha256=q1kMkW9-fOlrA4BOGUku40U_PuEYfcbtH2EvlRM4uTM,6243
91
91
  ddi_fw/utils/__init__.py,sha256=WNxkQXk-694roG50D355TGLXstfdWVb_tUyr-PM-8rg,537
92
92
  ddi_fw/utils/categorical_data_encoding_checker.py,sha256=T1X70Rh4atucAuqyUZmz-iFULllY9dY0NRyV9-jTjJ0,3438
@@ -101,7 +101,7 @@ ddi_fw/utils/zip_helper.py,sha256=YRZA4tKZVBJwGQM0_WK6L-y5MoqkKoC-nXuuHK6CU9I,55
101
101
  ddi_fw/vectorization/__init__.py,sha256=LcJOpLVoLvHPDw9phGFlUQGeNcST_zKV-Oi1Pm5h_nE,110
102
102
  ddi_fw/vectorization/feature_vector_generation.py,sha256=93G3QM28uoNlvlVz_BhV6ARxldpogiNJStxHdsgqTbU,6026
103
103
  ddi_fw/vectorization/idf_helper.py,sha256=_Gd1dtDSLaw8o-o0JugzSKMt9FpeXewTh4wGEaUd4VQ,2571
104
- ddi_fw-0.0.231.dist-info/METADATA,sha256=Cn7cy18IY6LxKcWxUf4Iad7ipvVZBQ0ZF5dwJnXwA08,2632
105
- ddi_fw-0.0.231.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
106
- ddi_fw-0.0.231.dist-info/top_level.txt,sha256=PMwHICFZTZtcpzQNPV4UQnfNXYIeLR_Ste-Wfc1h810,7
107
- ddi_fw-0.0.231.dist-info/RECORD,,
104
+ ddi_fw-0.0.232.dist-info/METADATA,sha256=CBSE9xsWEc0vxlCxw9NCLxTJWXmXoFTegN5wgrHVGvA,2632
105
+ ddi_fw-0.0.232.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
106
+ ddi_fw-0.0.232.dist-info/top_level.txt,sha256=PMwHICFZTZtcpzQNPV4UQnfNXYIeLR_Ste-Wfc1h810,7
107
+ ddi_fw-0.0.232.dist-info/RECORD,,