toolchemy 0.2.185__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- toolchemy/__main__.py +9 -0
- toolchemy/ai/clients/__init__.py +20 -0
- toolchemy/ai/clients/common.py +429 -0
- toolchemy/ai/clients/dummy_model_client.py +61 -0
- toolchemy/ai/clients/factory.py +37 -0
- toolchemy/ai/clients/gemini_client.py +48 -0
- toolchemy/ai/clients/ollama_client.py +58 -0
- toolchemy/ai/clients/openai_client.py +76 -0
- toolchemy/ai/clients/pricing.py +66 -0
- toolchemy/ai/clients/whisper_client.py +141 -0
- toolchemy/ai/prompter.py +124 -0
- toolchemy/ai/trackers/__init__.py +5 -0
- toolchemy/ai/trackers/common.py +216 -0
- toolchemy/ai/trackers/mlflow_tracker.py +221 -0
- toolchemy/ai/trackers/neptune_tracker.py +135 -0
- toolchemy/db/lightdb.py +260 -0
- toolchemy/utils/__init__.py +19 -0
- toolchemy/utils/at_exit_collector.py +109 -0
- toolchemy/utils/cacher/__init__.py +20 -0
- toolchemy/utils/cacher/cacher_diskcache.py +121 -0
- toolchemy/utils/cacher/cacher_pickle.py +152 -0
- toolchemy/utils/cacher/cacher_shelve.py +196 -0
- toolchemy/utils/cacher/common.py +174 -0
- toolchemy/utils/datestimes.py +77 -0
- toolchemy/utils/locations.py +111 -0
- toolchemy/utils/logger.py +76 -0
- toolchemy/utils/timer.py +23 -0
- toolchemy/utils/utils.py +168 -0
- toolchemy/vision/__init__.py +5 -0
- toolchemy/vision/caption_overlay.py +77 -0
- toolchemy/vision/image.py +89 -0
- toolchemy-0.2.185.dist-info/METADATA +25 -0
- toolchemy-0.2.185.dist-info/RECORD +36 -0
- toolchemy-0.2.185.dist-info/WHEEL +4 -0
- toolchemy-0.2.185.dist-info/entry_points.txt +3 -0
- toolchemy-0.2.185.dist-info/licenses/LICENSE +21 -0

toolchemy/ai/trackers/mlflow_tracker.py
ADDED

@@ -0,0 +1,221 @@
+import time
+from typing import Dict, Optional, Any
+
+import mlflow
+mlflow.autolog(disable=True)
+from mlflow.client import MlflowClient
+from mlflow.entities import RunStatus, Metric, Param
+from mlflow.tracking.context.registry import resolve_tags
+from mlflow.utils.mlflow_tags import MLFLOW_PARENT_RUN_ID, MLFLOW_RUN_NAME
+from mlflow import MlflowException
+
+from toolchemy.ai.trackers.common import TrackerBase
+from toolchemy.utils.logger import get_logger
+
+
+class MLFlowTracker(TrackerBase):
+    def __init__(self, tracking_uri: str, experiment_name: str, with_artifact_logging=True,
+                 tracking_client: MlflowClient | None = None):
+        super().__init__(experiment_name, with_artifact_logging)
+        self._active_run = None
+        self._active_run_id = None
+        self._experiment_id = None
+        self._reset_run()
+
+        if tracking_client:
+            self._tracking_client = tracking_client
+        else:
+            self._tracking_client = self._build_tracking_client(tracking_uri)
+
+        mlflow.set_tracking_uri(tracking_uri)
+
+        logger = get_logger()
+        logger.info("MLflow tracker created")
+        logger.info(f"> tracking uri: {tracking_uri}")
+        logger.info(f"> artifact logging: {self._artifact_logging}")
+
+    @property
+    def run_name(self) -> str:
+        if not self._active_run:
+            raise RuntimeError("There is no active run!")
+        return self._active_run.info.run_name
+
+    def start_run(
+            self, run_id: str | None = None,
+            run_name: str | None = None,
+            parent_run_id: str | None = None,
+            user_specified_tags: Dict[str, str] | None = None
+    ):
+        """
+        Starts a run.
+
+        :param run_id: If specified, resume the run with this ID and log parameters and metrics under it.
+        :param run_name: Name of the new run. Used only when run_id is unspecified.
+        :param parent_run_id: If specified, the current run will be nested under parent_run_id.
+        :param user_specified_tags: dict with custom tags
+        """
+        if self._disabled:
+            return
+
+        if run_id:
+            # Resume an existing run, as documented above, instead of creating a new one.
+            self._active_run = self._tracking_client.get_run(run_id)
+            self._active_run_id = run_id
+            self._experiment_id = self._active_run.info.experiment_id
+            return
+
+        if user_specified_tags is None:
+            user_specified_tags = {}
+        if parent_run_id is not None:
+            user_specified_tags[MLFLOW_PARENT_RUN_ID] = parent_run_id
+        if run_name:
+            user_specified_tags[MLFLOW_RUN_NAME] = run_name
+
+        tags = resolve_tags(user_specified_tags)
+
+        experiment = self._tracking_client.get_experiment_by_name(self._experiment_name)
+        if experiment:
+            experiment_comment_msg = "(already exists)"
+            self._logger.debug(f"Experiment '{self._experiment_name}' already exists")
+            self._experiment_id = experiment.experiment_id
+            if experiment.lifecycle_stage == "deleted":
+                self._logger.info("Restoring deleted experiment")
+                self._tracking_client.restore_experiment(self._experiment_id)
+        else:
+            experiment_comment_msg = "(does not exist, creating a new one)"
+            self._logger.debug(f"Experiment '{self._experiment_name}' does not exist, creating a new one")
+            self._experiment_id = self._tracking_client.create_experiment(self._experiment_name)
+
+        self._logger.info("Starting the experiment tracking")
+        self._logger.info(f"> experiment name: {self._experiment_name} {experiment_comment_msg}")
+        self._logger.info(f"> experiment id: {self._experiment_id}")
+        self._logger.info(f"> run name: {run_name}")
+
+        self._active_run = self._tracking_client.create_run(
+            experiment_id=self._experiment_id,
+            start_time=None,
+            run_name=run_name,
+            tags=tags)
+
+        self._active_run_id = self._active_run.info.run_id
+
+    def end_run(self):
+        if self._disabled:
+            return
+
+        status = RunStatus.to_string(RunStatus.FINISHED)
+        self._tracking_client.set_terminated(self._active_run_id, status)
+        self._reset_run()
+
+    def log(self, name: str, value: Any):
+        if self._disabled:
+            return
+
+        if isinstance(value, dict):
+            self._tracking_client.log_dict(self._active_run_id, value, f"{name}.json")
+            return
+
+        raise ValueError(f"Unsupported logged object type: {type(value)}")
+
+    def log_param(self, name: str, value: Any):
+        if self._disabled:
+            return
+
+        self._store_param(name, value)
+        self._tracking_client.log_param(self._active_run_id, name, value)
+
+    def log_params(self, params: Dict[str, Any]):
+        if self._disabled:
+            return
+
+        params_to_store = []
+        for key, value in params.items():
+            if isinstance(value, list):
+                for v in value:
+                    params_to_store.append(Param(key, str(v)))
+            else:
+                params_to_store.append(Param(key, str(value)))
+            self._store_param(key, value)
+
+        self._tracking_client.log_batch(self._active_run_id, [], params_to_store)
+
+    def log_text(self, name: str, value: str):
+        if self._disabled:
+            return
+        try:
+            self._tracking_client.log_text(run_id=self._active_run_id, text=value, artifact_file=name)
+        except MlflowException as e:
+            self._logger.error(f"An error occurred during text logging: {e}")
+            self._logger.error(f"> tracking uri: {self._tracking_client.tracking_uri}")
+            self._logger.error(f"> artifact uri: {self._active_run.info.artifact_uri}")
+            raise
+
+    def log_metric(self, name: str, value: float, step: int | None = None, metric_metadata: dict | None = None):
+        if self._disabled:
+            return
+
+        metric_value = self._store_metric(name, value, metric_metadata)
+        # step must be passed by keyword: the fourth positional argument of
+        # MlflowClient.log_metric is timestamp, not step.
+        self._tracking_client.log_metric(self._active_run_id, name, metric_value, step=step)
+
+    def log_metrics(self, metrics: Dict[str, float | list], step: Optional[int] = None):
+        if self._disabled:
+            return
+
+        metrics_to_store = []
+        timestamp = int(time.time() * 1000)
+        for k, value in metrics.items():
+            if isinstance(value, list):
+                for v in value:
+                    metric_value = self._store_metric(k, v)
+                    metrics_to_store.append(Metric(k, metric_value, timestamp, step or 0))
+            else:
+                metric_value = self._store_metric(k, value)
+                metrics_to_store.append(Metric(k, metric_value, timestamp, step or 0))
+
+        self._tracking_client.log_batch(self._active_run_id, metrics_to_store)
+
+    def log_artifact(self, artifact_path: str, save_dir: str | None = None):
+        if self._disabled:
+            return
+
+        if self._artifact_logging:
+            self._tracking_client.log_artifact(self._active_run_id, artifact_path, save_dir)
+
+    def log_figure(self, figure, save_path: str):
+        if self._disabled:
+            return
+        self._tracking_client.log_figure(self._active_run_id, figure, save_path)
+
+    def set_run_tag(self, name: str, value: str | int | float):
+        self._store_tag(name, value, run_name=self.run_name)
+        self._tracking_client.set_tag(self._active_run_id, name, value)
+
+    def set_experiment_tag(self, name: str, value: str | int | float):
+        self._store_tag(name, value)
+        self._tracking_client.set_experiment_tag(self._experiment_id, name, value)
+
+    @staticmethod
+    def _build_tracking_client(tracking_uri: str) -> MlflowClient:
+        return MlflowClient(tracking_uri=tracking_uri, registry_uri=tracking_uri)
+
+    def _reset_run(self):
+        self._active_run = None
+        self._active_run_id = None
+
+
+def play():
+    from toolchemy.utils.datestimes import current_datetime_str
+    from toolchemy.ai.prompter import PrompterMLflow
+    from toolchemy.utils import Locations
+
+    locations = Locations()
+    tracker = MLFlowTracker("http://hal:5000", f"test-{current_datetime_str()}")
+    tracker.start_run()
+    prompter = PrompterMLflow(locations.in_resources("tests/prompts_mlflow"))
+    print(prompter.render("test_prompt", foo="foo1", bar="bar1"))
+    tracker.log_param("param1", "param1value")
+    tracker.log_param("param2", 2)
+    tracker.log_metric("metric1", 1.0)
+    tracker.log_metric("metric2", 2.0, metric_metadata={"info": "metric 2 metadata"})
+    tracker.log_text("text_test", "some longer piece of text")
+    print(prompter.render("test_prompt", foo="foo1", bar="bar1"))
+    tracker.end_run()
+    print(prompter.render("test_prompt", foo="foo1", bar="bar1"))
+
+
+if __name__ == "__main__":
+    play()
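
Beyond the play() demo above, the parent_run_id argument is what produces nested runs in the MLflow UI (via the MLFLOW_PARENT_RUN_ID tag). A minimal sketch, assuming an MLflow server is reachable at the placeholder URI; the tracker exposes no public run-id accessor, so the sketch reads the private _active_run_id attribute purely for illustration:

# Sketch only: the tracking URI and experiment name are placeholders.
parent = MLFlowTracker("http://localhost:5000", "nested-demo")
parent.start_run(run_name="parent")

child = MLFlowTracker("http://localhost:5000", "nested-demo")
# _active_run_id is private; read here only because no public accessor exists.
child.start_run(run_name="child", parent_run_id=parent._active_run_id)
child.log_metric("loss", 0.1)
child.end_run()

parent.end_run()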

toolchemy/ai/trackers/neptune_tracker.py
ADDED

@@ -0,0 +1,135 @@
+from typing import Dict, Optional, Any
+
+from neptune_scale import Run
+
+from toolchemy.ai.trackers.common import TrackerBase
+from toolchemy.utils.datestimes import current_datetime_str
+
+
+class NeptuneAITracker(TrackerBase):
+    def __init__(self, project_name: str, experiment_name: str, api_token: str, with_artifact_logging: bool = True,
+                 disabled: bool = False):
+        super().__init__(experiment_name=experiment_name, with_artifact_logging=with_artifact_logging, disabled=disabled)
+        self._api_token = api_token
+        self._project_name = project_name
+        self._active_run: Run | None = None
+        self._active_run_id: str | None = None
+
+    @property
+    def run_name(self) -> str:
+        if not self._active_run or not self._active_run_id:
+            raise RuntimeError("There is no active run!")
+        return self._active_run_id
+
+    def start_run(
+            self, run_id: str | None = None,
+            run_name: str | None = None,
+            parent_run_id: str | None = None,
+            user_specified_tags: Dict[str, str] | None = None
+    ):
+        if self._disabled:
+            return
+        if self._active_run or self._active_run_id:
+            raise RuntimeError("Cannot start a new run, there is already an active run")
+        if run_name is not None:
+            self._logger.warning("Neptune tracker uses 'run_id' as the run name. Use 'run_id' for the custom run name.")
+            if run_id is None:
+                run_id = run_name
+
+        if run_id is None:
+            run_id = self._generate_run_name()
+
+        self._active_run_id = run_id
+
+        self._active_run = Run(
+            project=self._project_name,
+            api_token=self._api_token,
+            experiment_name=self._experiment_name,
+            run_id=self._active_run_id,
+            enable_console_log_capture=True,
+        )
+        self._logger.info(f"Neptune tracking run started. Experiment name: {self.experiment_name}")
+
+    def end_run(self):
+        if self._disabled:
+            return
+        if self._active_run is None:
+            raise ValueError("No active run to stop")
+        self._active_run.close()
+        self._active_run = None
+        self._active_run_id = None
+
+    def log(self, name: str, value: Any):
+        if self._disabled:
+            return
+        raise NotImplementedError()
+
+    def log_param(self, name: str, val, step: Optional[int] = None):
+        if self._disabled:
+            return
+        self._store_param(name, val)
+        self._active_run.log_configs({name: val})
+
+    def log_params(self, params: Dict[str, Any]):
+        if self._disabled:
+            return
+        for param_name, param_value in params.items():
+            self._store_param(param_name, param_value)
+        self._active_run.log_configs(params)
+
+    def log_text(self, name: str, value: str):
+        if self._disabled:
+            return
+        self._active_run.log_configs({name: value})
+
+    def log_metric(self, name: str, value: float, step: int | None = None, metric_metadata: dict | None = None):
+        if self._disabled:
+            return
+        self.log_metrics({name: value}, step)
+
+    def log_metrics(self, metrics: Dict[str, float | list], step: Optional[int] = None):
+        if self._disabled:
+            return
+
+        # Build a new dict so the caller's dict is not mutated.
+        processed = {k: self._store_metric(k, v) for k, v in metrics.items()}
+        self._active_run.log_metrics(processed, step=step)
+
+    def log_artifact(self, artifact_path: str, save_dir: str | None = None):
+        if self._disabled:
+            return
+        if not self._artifact_logging:
+            return
+        if save_dir is None:
+            save_dir = ""
+
+        self._active_run.assign_files({save_dir: artifact_path})
+
+    def log_figure(self, figure, save_path: str):
+        if self._disabled:
+            return
+        if save_path is None:
+            save_path = ""
+
+        self._active_run.assign_files({save_path: figure})
+
+    def set_run_tag(self, name: str, value: str | int | float):
+        self._store_tag(name, value, run_name=self.run_name)
+        self._active_run.add_tags([f"{name}__{value}"])
+
+    def set_experiment_tag(self, name: str, value: str | int | float):
+        self._store_tag(name, value)
+        if self._active_run:
+            self.set_run_tag(name, value)
+
+    def get_id(self) -> str:
+        if self._disabled:
+            return ""
+        if not self._active_run:
+            raise ValueError("No active run")
+        return self._active_run_id
+
+    def _generate_run_name(self) -> str:
+        return f"{self.experiment_name}__{current_datetime_str('%Y_%m_%d__%H_%M_%S')}"
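
Since neptune_scale identifies runs by run_id rather than a display name, a typical session looks like the following sketch ("workspace/project" and the environment-variable lookup are placeholders):

import os

tracker = NeptuneAITracker(
    project_name="workspace/project",           # placeholder project path
    experiment_name="neptune-demo",
    api_token=os.environ["NEPTUNE_API_TOKEN"],  # assumes the token is exported in the environment
)
tracker.start_run()  # run_id defaults to "<experiment_name>__<timestamp>"
tracker.log_params({"lr": 1e-3, "epochs": 10})
tracker.log_metric("accuracy", 0.92, step=1)
tracker.end_run()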
toolchemy/db/lightdb.py
ADDED

@@ -0,0 +1,260 @@
+import logging
+import uuid
+from tqdm import tqdm
+from collections import defaultdict
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any
+from tinydb import TinyDB, Query
+
+from toolchemy.utils.logger import get_logger
+from toolchemy.utils.utils import hash_dict, pp
+from toolchemy.utils.cacher import Cacher, ICacher
+
+
+class NotFoundError(Exception):
+    pass
+
+
+class FilterOp(Enum):
+    GREATER = 1
+    GREATER_OR_EQUAL = 2
+    LESS = 3
+    LESS_OR_EQUAL = 4
+    EQUAL = 5
+
+
+@dataclass
+class Filter:
+    key: str
+    value: Any
+    op: FilterOp = FilterOp.EQUAL
+
+    def __str__(self) -> str:
+        return f"Filter(key={self.key}, op={self.op.name}, value={self.value})"
+
+
+class ILightDB(ABC):
+    ID_FIELD = "id_"
+    HASH_FIELD = "hash_"
+
+    @abstractmethod
+    def insert(self, doc: dict) -> str:
+        pass
+
+    @abstractmethod
+    def insert_batch(self, docs: list[dict]) -> list[str]:
+        pass
+
+    @abstractmethod
+    def update(self, doc: dict) -> str:
+        pass
+
+    @abstractmethod
+    def upsert(self, doc: dict) -> str:
+        pass
+
+    @abstractmethod
+    def retrieve(self, doc_id: str) -> dict | None:
+        pass
+
+    @abstractmethod
+    def search(self, query_filter: Filter) -> list[dict]:
+        pass
+
+    @abstractmethod
+    def all(self) -> list[dict]:
+        pass
+
+    @abstractmethod
+    def remove(self, ids: list[str]) -> int:
+        pass
+
+
+class LightTinyDB(ILightDB):
+    def __init__(self, db_file_path: str, indexes: list[str] | None = None, log_level: int = logging.INFO, cacher: ICacher | None = None):
+        self._logger = get_logger(level=log_level)
+        self._db = TinyDB(db_file_path)
+
+        self._cacher = cacher or Cacher()
+
+        self._indexes = {}
+
+        for index_name in (indexes or []):
+            self._create_and_rebuild_index(index_name)
+
+    def all(self) -> list[dict]:
+        return [dict(document) for document in self._db.all()]
+
+    def insert(self, doc: dict) -> str:
+        doc_prepared = self._prepare_doc_for_store(doc)
+        cache_key = self._cacher.create_cache_key("exists", [doc])
+        if not self._cacher.exists(cache_key):
+            self._logger.debug(f"Inserting doc: {doc_prepared}")
+            self._db.insert(doc_prepared)
+            self._handle_index(doc_prepared)
+            self._cacher.set(cache_key, doc_prepared[self.ID_FIELD])
+        return doc_prepared[self.ID_FIELD]
+
+    def insert_batch(self, docs: list[dict]) -> list[str]:
+        docs_prepared = [self._prepare_doc_for_store(doc) for doc in docs]
+        docs_prepared = [doc for doc in docs_prepared if not self._cacher.exists(self._cacher.create_cache_key("exists", [doc]))]
+        self._logger.debug(f"Inserting doc batch: {docs_prepared}")
+        self._db.insert_multiple(docs_prepared)
+        for doc_prepared in docs_prepared:
+            self._handle_index(doc_prepared)
+        for doc in docs_prepared:
+            self._cacher.set(self._cacher.create_cache_key("exists", [doc]), doc[self.ID_FIELD])
+        return [doc[self.ID_FIELD] for doc in docs_prepared]
+
+    def update(self, doc: dict) -> str:
+        doc = self._prepare_doc_for_store(doc)
+        q = Query()
+        self._db.update(doc, q.id_ == doc[self.ID_FIELD])
+        self._handle_index(doc)
+        return doc[self.ID_FIELD]
+
+    def upsert(self, doc: dict) -> str:
+        doc = self._prepare_doc_for_store(doc)
+        q = Query()
+        self._db.upsert(doc, q.id_ == doc[self.ID_FIELD])
+        self._handle_index(doc)
+        return doc[self.ID_FIELD]
+
+    def retrieve(self, doc_id: str) -> dict | None:
+        docs = self.search(Filter(self.ID_FIELD, doc_id))
+        if len(docs) < 1:
+            return None
+        return docs[0]
+
+    def search(self, query_filter: Filter) -> list[dict]:
+        self._logger.debug(f"Searching with filter: {query_filter}")
+        if query_filter.op == FilterOp.EQUAL and query_filter.value is not None:
+            if self._has_index(query_filter.key):
+                return self._search_index(query_filter.key, query_filter.value)
+        tinydb_query = Query()
+        documents = self._db.search(tinydb_query[query_filter.key].test(self._filter_to_test_fn(query_filter)))
+        self._logger.debug(f"> documents found: {len(documents)}")
+        return [self._prepare_doc_for_return(dict(document)) for document in documents]
+
+    def _create_and_rebuild_index(self, field_name: str) -> None:
+        self._logger.debug(f"Creating index for: '{field_name}'...")
+        self._indexes[field_name] = defaultdict(list)
+        docs = self._db.all()
+        for doc in tqdm(docs, desc=f"recreating index '{field_name}'"):
+            self._add_to_index(field_name, doc)
+        self._logger.debug(f"> indexed {len(docs)} documents")
+
+    def _add_to_index(self, field_name: str, doc: dict) -> None:
+        if field_name not in doc:
+            return
+        self._indexes[field_name][doc[field_name]].append(doc)
+
+    def _remove_from_index(self, field_name: str, doc: dict) -> None:
+        if field_name not in doc:
+            return
+        # Iterate over a copy: removing items from the list being iterated skips elements.
+        for indexed_doc in list(self._indexes[field_name][doc[field_name]]):
+            if indexed_doc[self.ID_FIELD] == doc[self.ID_FIELD]:
+                self._indexes[field_name][doc[field_name]].remove(indexed_doc)
+
+    def _handle_index(self, doc: dict, remove: bool = False) -> None:
+        if remove:
+            self._handle_index_remove(doc)
+        else:
+            self._handle_index_add(doc)
+
+    def _handle_index_add(self, doc: dict) -> None:
+        for field_name in doc.keys():
+            if field_name in self._indexes:
+                self._add_to_index(field_name, doc)
+
+    def _handle_index_remove(self, doc: dict):
+        self._logger.debug(f"_handle_index_remove| doc: {doc}")
+        if len(doc.keys()) == 1 and list(doc.keys())[0] == self.ID_FIELD:
+            self._logger.debug("> the doc has a single key, trying to get the full document")
+            doc = self.retrieve(doc[self.ID_FIELD])
+            if doc is None:
+                return
+            self._logger.debug(f"> the full document: {doc}")
+        for doc_field in doc.keys():
+            if doc_field in self._indexes:
+                self._remove_from_index(doc_field, doc)
+
+    def _has_index(self, key: str) -> bool:
+        if key not in self._indexes:
+            self._logger.debug(f"> there is no index: '{key}'")
+            return False
+        return True
+
+    def _search_index(self, key: str, value: Any) -> list[dict]:
+        if value not in self._indexes[key]:
+            return []
+        # Return copies, like the non-indexed search path, so callers cannot mutate the index.
+        return [self._prepare_doc_for_return(doc) for doc in self._indexes[key][value]]
+
+    def _filter_to_test_fn(self, query_filter: Filter):
+        def test_func(val):
+            if not isinstance(val, (str, int, float)):
+                return False
+            if query_filter.op == FilterOp.GREATER:
+                return val > query_filter.value
+            if query_filter.op == FilterOp.GREATER_OR_EQUAL:
+                return val >= query_filter.value
+            if query_filter.op == FilterOp.LESS:
+                return val < query_filter.value
+            if query_filter.op == FilterOp.LESS_OR_EQUAL:
+                return val <= query_filter.value
+            if query_filter.op == FilterOp.EQUAL:
+                return val == query_filter.value
+            raise ValueError(f"Unknown operator {query_filter.op}")
+        return test_func
+
+    def remove(self, ids: list[str] | str) -> int:
+        if isinstance(ids, str):
+            ids = [ids]
+        removed_elements_ids = []
+        for doc_id in ids:
+            self._handle_index({self.ID_FIELD: doc_id}, remove=True)
+            q = Query()
+            removed_elements = self._db.remove(q.id_ == doc_id)
+            removed_elements_ids.extend(removed_elements)
+        return len(set(removed_elements_ids))
+
+    def _prepare_doc_for_store(self, doc: dict) -> dict:
+        self._ensure_created_at(doc)
+        return self._ensure_hash(self._ensure_id(doc))
+
+    def _prepare_doc_for_return(self, doc: dict) -> dict:
+        doc_copy = doc.copy()
+        self._ensure_created_at(doc_copy)
+        return doc_copy
+
+    def _ensure_created_at(self, doc: dict):
+        if "created_at" not in doc:
+            err_msg = f"There is no created_at in doc!\n{pp(doc, print_msg=False)}"
+            self._logger.error(err_msg)
+            raise ValueError(err_msg)
+
+    def _ensure_id(self, doc: dict) -> dict:
+        if self.ID_FIELD not in doc:
+            doc[self.ID_FIELD] = self._generate_id()
+        return doc
+
+    def _generate_id(self) -> str:
+        return str(uuid.uuid4())
+
+    def _ensure_hash(self, doc: dict) -> dict:
+        if self.HASH_FIELD in doc and doc[self.HASH_FIELD]:
+            self._logger.debug(f"the doc already has the hash property, keeping it as is: {doc[self.HASH_FIELD]}")
+            return doc
+        doc[self.HASH_FIELD] = self._generate_hash(doc)
+        return doc
+
+    def _generate_hash(self, doc: dict) -> str:
+        doc_copy = doc.copy()
+        if self.ID_FIELD in doc_copy:
+            del doc_copy[self.ID_FIELD]
+        if self.HASH_FIELD in doc_copy:
+            del doc_copy[self.HASH_FIELD]
+        return hash_dict(doc_copy)
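
LightTinyDB serves EQUAL queries on indexed fields from its in-memory index and falls back to a TinyDB scan for everything else; every stored document must already carry a created_at field or the store raises. A minimal usage sketch (the file name and field values are placeholders):

db = LightTinyDB("scores.json", indexes=["name"])

doc_id = db.insert({"name": "alice", "score": 91,
                    "created_at": "2025-01-01T00:00:00"})  # created_at is mandatory

exact = db.search(Filter("name", "alice"))                        # served from the in-memory index
high = db.search(Filter("score", 90, FilterOp.GREATER_OR_EQUAL))  # falls back to a TinyDB scan

db.remove(doc_id)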

toolchemy/utils/__init__.py
ADDED

@@ -0,0 +1,19 @@
+from .cacher import ICacher, CacherPickle, Cacher, CacheEntryDoesNotExistError, DummyCacher
+from .locations import Locations
+from .logger import get_logger
+from .utils import pp, pp_cast, ff, hash_dict, to_json, truncate
+from .timer import Timer
+from .at_exit_collector import ICollectable, AtExitCollector
+
+
+__all__ = [
+    "ICollectable", "AtExitCollector",
+    "ICacher",
+    "CacherPickle",
+    "Cacher",
+    "CacheEntryDoesNotExistError",
+    "DummyCacher",
+    "get_logger",
+    "Locations",
+    "pp", "pp_cast", "ff", "hash_dict", "to_json", "truncate",
+    "Timer"]