toolchemy 0.2.185__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. toolchemy/__main__.py +9 -0
  2. toolchemy/ai/clients/__init__.py +20 -0
  3. toolchemy/ai/clients/common.py +429 -0
  4. toolchemy/ai/clients/dummy_model_client.py +61 -0
  5. toolchemy/ai/clients/factory.py +37 -0
  6. toolchemy/ai/clients/gemini_client.py +48 -0
  7. toolchemy/ai/clients/ollama_client.py +58 -0
  8. toolchemy/ai/clients/openai_client.py +76 -0
  9. toolchemy/ai/clients/pricing.py +66 -0
  10. toolchemy/ai/clients/whisper_client.py +141 -0
  11. toolchemy/ai/prompter.py +124 -0
  12. toolchemy/ai/trackers/__init__.py +5 -0
  13. toolchemy/ai/trackers/common.py +216 -0
  14. toolchemy/ai/trackers/mlflow_tracker.py +221 -0
  15. toolchemy/ai/trackers/neptune_tracker.py +135 -0
  16. toolchemy/db/lightdb.py +260 -0
  17. toolchemy/utils/__init__.py +19 -0
  18. toolchemy/utils/at_exit_collector.py +109 -0
  19. toolchemy/utils/cacher/__init__.py +20 -0
  20. toolchemy/utils/cacher/cacher_diskcache.py +121 -0
  21. toolchemy/utils/cacher/cacher_pickle.py +152 -0
  22. toolchemy/utils/cacher/cacher_shelve.py +196 -0
  23. toolchemy/utils/cacher/common.py +174 -0
  24. toolchemy/utils/datestimes.py +77 -0
  25. toolchemy/utils/locations.py +111 -0
  26. toolchemy/utils/logger.py +76 -0
  27. toolchemy/utils/timer.py +23 -0
  28. toolchemy/utils/utils.py +168 -0
  29. toolchemy/vision/__init__.py +5 -0
  30. toolchemy/vision/caption_overlay.py +77 -0
  31. toolchemy/vision/image.py +89 -0
  32. toolchemy-0.2.185.dist-info/METADATA +25 -0
  33. toolchemy-0.2.185.dist-info/RECORD +36 -0
  34. toolchemy-0.2.185.dist-info/WHEEL +4 -0
  35. toolchemy-0.2.185.dist-info/entry_points.txt +3 -0
  36. toolchemy-0.2.185.dist-info/licenses/LICENSE +21 -0
toolchemy/ai/trackers/mlflow_tracker.py
@@ -0,0 +1,221 @@
+ import time
+ from typing import Dict, Optional, Any
+
+ import mlflow
+ mlflow.autolog(disable=True)
+ from mlflow.client import MlflowClient
+ from mlflow.entities import RunStatus, Metric, Param
+ from mlflow.tracking.context.registry import resolve_tags
+ from mlflow.utils.mlflow_tags import MLFLOW_PARENT_RUN_ID, MLFLOW_RUN_NAME
+ from mlflow import MlflowException
+
+ from toolchemy.ai.trackers.common import TrackerBase
+ from toolchemy.utils.logger import get_logger
+
+
+ class MLFlowTracker(TrackerBase):
+     def __init__(self, tracking_uri: str, experiment_name: str, with_artifact_logging=True,
+                  tracking_client: MlflowClient | None = None):
+         super().__init__(experiment_name, with_artifact_logging)
+         self._tracking_client = None
+         self._active_run = None
+         self._active_run_id = None
+         self._experiment_id = None
+         self._reset_run()
+
+         if tracking_client:
+             self._tracking_client = tracking_client
+         else:
+             self._tracking_client = self._build_tracking_client(tracking_uri)
+
+         mlflow.set_tracking_uri(tracking_uri)
+
+         logger = get_logger()
+         logger.info("MLflow tracker created")
+         logger.info(f"> tracking uri: {tracking_uri}")
+         logger.info(f"> artifact logging: {self._artifact_logging}")
+
+     @property
+     def run_name(self) -> str:
+         if not self._active_run:
+             raise RuntimeError("There is no active run!")
+         return self._active_run.info.run_name
+
+     def start_run(
+             self, run_id: str = None,
+             run_name: str = None,
+             parent_run_id: str = None,
+             user_specified_tags: Dict[str, str] = None
+     ):
+         """
+         Starts the run.
+
+         :param run_id: If specified, get the run with the specified ID and log parameters and metrics under that run.
+         :param run_name: Name of the new run. Used only when run_id is unspecified.
+         :param parent_run_id: If specified, the current run will be nested under parent_run_id.
+         :param user_specified_tags: dict with custom tags
+         """
+         if self._disabled:
+             return
+
+         if user_specified_tags is None:
+             user_specified_tags = {}
+         if parent_run_id is not None:
+             user_specified_tags[MLFLOW_PARENT_RUN_ID] = parent_run_id
+         if run_name:
+             user_specified_tags[MLFLOW_RUN_NAME] = run_name
+
+         tags = resolve_tags(user_specified_tags)
+
+         experiment = self._tracking_client.get_experiment_by_name(self._experiment_name)
+         if experiment:
+             experiment_comment_msg = "(already exists)"
+             self._logger.debug(f"Experiment '{self._experiment_name}' already exists")
+             self._experiment_id = experiment.experiment_id
+             if experiment.lifecycle_stage == "deleted":
+                 self._logger.info("Restoring deleted experiment")
+                 self._tracking_client.restore_experiment(self._experiment_id)
+         else:
+             experiment_comment_msg = "(does not exist, creating a new one)"
+             self._logger.debug(f"Experiment '{self._experiment_name}' does not exist, creating a new one")
+             self._experiment_id = self._tracking_client.create_experiment(self._experiment_name)
+
+         self._logger.info("Starting the experiment tracking")
+         self._logger.info(f"> experiment name: {self._experiment_name} {experiment_comment_msg}")
+         self._logger.info(f"> experiment id: {self._experiment_id}")
+         self._logger.info(f"> run name: {run_name}")
+
+         self._active_run = self._tracking_client.create_run(
+             experiment_id=self._experiment_id,
+             start_time=None,
+             run_name=run_name,
+             tags=tags)
+
+         self._active_run_id = self._active_run.info.run_id
+
+     def end_run(self):
+         if self._disabled:
+             return
+
+         status = RunStatus.to_string(RunStatus.FINISHED)
+         self._tracking_client.set_terminated(self._active_run_id, status)
+         self._reset_run()
+
+     def log(self, name: str, value: Any):
+         if self._disabled:
+             return
+
+         if isinstance(value, dict):
+             self._tracking_client.log_dict(self._active_run_id, value, f"{name}.json")
+         else:
+             raise ValueError(f"Unsupported logged object type: {type(value)}")
+
+     def log_param(self, name: str, value: Any):
+         if self._disabled:
+             return
+
+         self._store_param(name, value)
+         self._tracking_client.log_param(self._active_run_id, name, value)
+
+     def log_params(self, params: Dict[str, Any]):
+         if self._disabled:
+             return
+
+         params_to_store = []
+         for key, value in params.items():
+             if isinstance(value, list):
+                 for v in value:
+                     params_to_store.append(Param(key, str(v)))
+             else:
+                 params_to_store.append(Param(key, str(value)))
+             self._store_param(key, value)
+
+         self._tracking_client.log_batch(self._active_run_id, [], params_to_store)
+
+     def log_text(self, name: str, value: str):
+         if self._disabled:
+             return
+         try:
+             self._tracking_client.log_text(run_id=self._active_run_id, text=value, artifact_file=name)
+         except MlflowException as e:
+             self._logger.error(f"An error occurred during text logging: {e}")
+             self._logger.error(f"> tracking uri: {self._tracking_client.tracking_uri}")
+             self._logger.error(f"> artifact uri: {self._active_run.info.artifact_uri}")
+             raise
+
+     def log_metric(self, name: str, value: float, step: int | None = None, metric_metadata: dict | None = None):
+         if self._disabled:
+             return
+
+         metric_value = self._store_metric(name, value, metric_metadata)
+         self._tracking_client.log_metric(self._active_run_id, name, metric_value, step=step)
+
+     def log_metrics(self, metrics: Dict[str, float | list], step: Optional[int] = None):
+         if self._disabled:
+             return
+
+         metrics_to_store = []
+         timestamp = int(time.time() * 1000)
+         for k, value in metrics.items():
+             if isinstance(value, list):
+                 for v in value:
+                     metric_value = self._store_metric(k, v)
+                     metrics_to_store.append(Metric(k, metric_value, timestamp, step or 0))
+             else:
+                 metric_value = self._store_metric(k, value)
+                 metrics_to_store.append(Metric(k, metric_value, timestamp, step or 0))
+
+         self._tracking_client.log_batch(self._active_run_id, metrics_to_store)
+
+     def log_artifact(self, artifact_path: str, save_dir: str = None):
+         if self._disabled:
+             return
+
+         if self._artifact_logging:
+             self._tracking_client.log_artifact(self._active_run_id, artifact_path, save_dir)
+
+     def log_figure(self, figure, save_path: str):
+         if self._disabled:
+             return
+         self._tracking_client.log_figure(self._active_run_id, figure, save_path)
+
+     def set_run_tag(self, name: str, value: str | int | float):
+         self._store_tag(name, value, run_name=self.run_name)
+         self._tracking_client.set_tag(self._active_run_id, name, value)
+
+     def set_experiment_tag(self, name: str, value: str | int | float):
+         self._store_tag(name, value)
+         self._tracking_client.set_experiment_tag(self._experiment_id, name, value)
+
+     @staticmethod
+     def _build_tracking_client(tracking_uri: str) -> MlflowClient:
+         tracking_client = mlflow.tracking.MlflowClient(tracking_uri=tracking_uri, registry_uri=tracking_uri)
+         return tracking_client
+
+     def _reset_run(self):
+         self._active_run = None
+         self._active_run_id = None
+
+
+ def play():
+     from toolchemy.utils.datestimes import current_datetime_str
+     from toolchemy.ai.prompter import PrompterMLflow
+     from toolchemy.utils import Locations
+
+     locations = Locations()
+     tracker = MLFlowTracker("http://hal:5000", f"test-{current_datetime_str()}")
+     tracker.start_run()
+     prompter = PrompterMLflow(locations.in_resources("tests/prompts_mlflow"))
+     print(prompter.render("test_prompt", foo="foo1", bar="bar1"))
+     tracker.log_param("param1", "param1value")
+     tracker.log_param("param2", 2)
+     tracker.log_metric("metric1", 1.0)
+     tracker.log_metric("metric2", 2.0, metric_metadata={"info": "metric 2 metadata"})
+     tracker.log_text("text_test", "some longer piece of text")
+     print(prompter.render("test_prompt", foo="foo1", bar="bar1"))
+     tracker.end_run()
+     print(prompter.render("test_prompt", foo="foo1", bar="bar1"))
+
+
+ if __name__ == "__main__":
+     play()
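
Note: both tracker classes in this release extend TrackerBase from toolchemy/ai/trackers/common.py (+216 lines, not reproduced here). Below is a minimal sketch of the base-class surface the subclasses rely on, inferred purely from the calls above (_disabled, _logger, experiment_name, _store_param, _store_metric, _store_tag); the internal dict names and the behavior of _store_metric are assumptions, not the actual implementation:

# Hypothetical sketch of TrackerBase, inferred from how MLFlowTracker and
# NeptuneAITracker use it; the real class in toolchemy/ai/trackers/common.py
# is 216 lines and certainly does more.
import logging
from abc import ABC
from typing import Any


class TrackerBase(ABC):
    def __init__(self, experiment_name: str, with_artifact_logging: bool = True, disabled: bool = False):
        self._experiment_name = experiment_name
        self._artifact_logging = with_artifact_logging
        self._disabled = disabled
        self._logger = logging.getLogger(__name__)
        # assumed in-memory mirrors of everything logged
        self._params: dict[str, Any] = {}
        self._metrics: dict[str, list] = {}
        self._tags: dict[str, Any] = {}

    @property
    def experiment_name(self) -> str:
        return self._experiment_name

    def _store_param(self, name: str, value: Any) -> None:
        self._params[name] = value

    def _store_metric(self, name: str, value: float, metric_metadata: dict | None = None) -> float:
        # subclasses use the returned value, so the base presumably
        # normalizes/records the metric and hands the value back
        self._metrics.setdefault(name, []).append((value, metric_metadata))
        return float(value)

    def _store_tag(self, name: str, value: Any, run_name: str | None = None) -> None:
        self._tags[name] = value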
toolchemy/ai/trackers/neptune_tracker.py
@@ -0,0 +1,135 @@
+ from typing import Dict, Optional, Any
+ from neptune_scale import Run
+
+ from toolchemy.ai.trackers.common import TrackerBase
+ from toolchemy.utils.datestimes import current_datetime_str
+
+
+ class NeptuneAITracker(TrackerBase):
+     def __init__(self, project_name: str, experiment_name: str, api_token: str, with_artifact_logging: bool = True,
+                  disabled: bool = False):
+         super().__init__(experiment_name=experiment_name, with_artifact_logging=with_artifact_logging, disabled=disabled)
+         self._api_token = api_token
+         self._project_name = project_name
+         self._active_run: Run | None = None
+         self._active_run_id: str | None = None
+
+     @property
+     def run_name(self) -> str:
+         if not self._active_run or not self._active_run_id:
+             raise RuntimeError("There is no active run!")
+         return self._active_run_id
+
+     def start_run(
+             self, run_id: str = None,
+             run_name: str = None,
+             parent_run_id: str = None,
+             user_specified_tags: Dict[str, str] = None
+     ):
+         if self._disabled:
+             return
+         if self._active_run or self._active_run_id:
+             raise RuntimeError("Cannot start a new run, there is already an active run")
+         if run_name is not None:
+             self._logger.warning("Neptune tracker uses 'run_id' as the run name. Use 'run_id' for the custom run name.")
+         if run_id is None:
+             run_id = run_name
+
+         if run_id is None:
+             run_id = self._generate_run_name()
+
+         self._active_run_id = run_id
+
+         self._active_run = Run(
+             project=self._project_name,
+             api_token=self._api_token,
+             experiment_name=self._experiment_name,
+             run_id=self._active_run_id,
+             enable_console_log_capture=True,
+         )
+         self._logger.info(f"Neptune tracking run started. Experiment name: {self.experiment_name}")
+
+     def end_run(self):
+         if self._disabled:
+             return
+         if self._active_run is None:
+             raise ValueError("No active run to stop")
+         self._active_run.close()
+         self._active_run = None
+         self._active_run_id = None
+
+     def log(self, name: str, value: Any):
+         if self._disabled:
+             return
+         raise NotImplementedError()
+
+     def log_param(self, name: str, val, step: Optional[int] = None):
+         if self._disabled:
+             return
+         self._store_param(name, val)
+         self._active_run.log_configs({name: val})
+
+     def log_params(self, params: Dict[str, Any]):
+         if self._disabled:
+             return
+         for param_name, param_value in params.items():
+             self._store_param(param_name, param_value)
+         self._active_run.log_configs(params)
+
+     def log_text(self, name: str, value: str):
+         if self._disabled:
+             return
+         self._active_run.log_configs({name: value})
+
+     def log_metric(self, name: str, value: float, step: int | None = None, metric_metadata: dict | None = None):
+         if self._disabled:
+             return
+         self.log_metrics({name: value}, step)
+
+     def log_metrics(self, metrics: Dict[str, float | list], step: Optional[int] = None):
+         if self._disabled:
+             return
+
+         for k, v in metrics.items():
+             metric_value = self._store_metric(k, v)
+             metrics[k] = metric_value
+
+         self._active_run.log_metrics(metrics, step=step)
+
+     def log_artifact(self, artifact_path: str, save_dir: str = None):
+         if self._disabled:
+             return
+         if not self._artifact_logging:
+             return
+         if save_dir is None:
+             save_dir = ""
+
+         self._active_run.assign_files({save_dir: artifact_path})
+
+     def log_figure(self, figure, save_path: str):
+         if self._disabled:
+             return
+         if save_path is None:
+             save_path = ""
+
+         self._active_run.assign_files({save_path: figure})
+
+     def set_run_tag(self, name: str, value: str | int | float):
+         self._store_tag(name, value, run_name=self.run_name)
+         self._active_run.add_tags([f"{name}__{value}"])
+
+     def set_experiment_tag(self, name: str, value: str | int | float):
+         self._store_tag(name, value)
+         if self._active_run:
+             self.set_run_tag(name, value)
+
+     def get_id(self):
+         if self._disabled:
+             return ""
+         if not self._active_run:
+             raise ValueError("No active run")
+         return self._active_run_id
+
+     def _generate_run_name(self) -> str:
+         return f"{self.experiment_name}__{current_datetime_str('%Y_%m_%d__%H_%M_%S')}"
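
For reference, a minimal usage sketch of the Neptune tracker, analogous to the play() demo in mlflow_tracker.py; the project name, the NEPTUNE_API_TOKEN environment variable, and the parameter/metric values are placeholders, not values from the package:

# Hypothetical usage; "workspace/sandbox" and NEPTUNE_API_TOKEN are placeholders.
import os

from toolchemy.ai.trackers.neptune_tracker import NeptuneAITracker

tracker = NeptuneAITracker(
    project_name="workspace/sandbox",
    experiment_name="neptune-smoke-test",
    api_token=os.environ["NEPTUNE_API_TOKEN"],
)
tracker.start_run(run_id="neptune-smoke-test__manual")  # run_id doubles as the run name
tracker.log_params({"lr": 1e-3, "batch_size": 32})
tracker.log_metric("loss", 0.42, step=1)
tracker.end_run()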
toolchemy/db/lightdb.py
@@ -0,0 +1,260 @@
+ import logging
+ import uuid
+ from tqdm import tqdm
+ from collections import defaultdict
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from enum import Enum
+ from typing import Any
+ from tinydb import TinyDB, Query
+
+ from toolchemy.utils.logger import get_logger
+ from toolchemy.utils.utils import hash_dict, pp
+ from toolchemy.utils.cacher import Cacher, ICacher
+
+
+ class NotFoundError(Exception):
+     pass
+
+
+ class FilterOp(Enum):
+     GREATER = 1
+     GREATER_OR_EQUAL = 2
+     LESS = 3
+     LESS_OR_EQUAL = 4
+     EQUAL = 5
+
+
+ @dataclass
+ class Filter:
+     key: str
+     value: Any
+     op: FilterOp = FilterOp.EQUAL
+
+     def __str__(self) -> str:
+         return f"Filter(key={self.key}, op={self.op.name}, value={self.value})"
+
+
+ class ILightDB(ABC):
+     ID_FIELD = "id_"
+     HASH_FIELD = "hash_"
+
+     @abstractmethod
+     def insert(self, doc: dict) -> str:
+         pass
+
+     @abstractmethod
+     def insert_batch(self, docs: list[dict]) -> list[str]:
+         pass
+
+     @abstractmethod
+     def update(self, doc: dict) -> str:
+         pass
+
+     @abstractmethod
+     def upsert(self, doc: dict) -> str:
+         pass
+
+     @abstractmethod
+     def retrieve(self, doc_id: str) -> dict | None:
+         pass
+
+     @abstractmethod
+     def search(self, query_filter: Filter) -> list[dict]:
+         pass
+
+     @abstractmethod
+     def all(self) -> list[dict]:
+         pass
+
+     @abstractmethod
+     def remove(self, ids: list[str]) -> int:
+         pass
+
+
+ class LightTinyDB(ILightDB):
+     def __init__(self, db_file_path: str, indexes: list[str] | None = None, log_level: int = logging.INFO, cacher: ICacher | None = None):
+         self._logger = get_logger(level=log_level)
+         self._db = TinyDB(db_file_path)
+
+         self._cacher = cacher or Cacher()
+
+         self._indexes = {}
+
+         for index_name in (indexes or []):
+             self._create_and_rebuild_index(index_name)
+
+     def all(self) -> list[dict]:
+         return [dict(document) for document in self._db.all()]
+
+     def insert(self, doc: dict) -> str:
+         doc_prepared = self._prepare_doc_for_store(doc)
+         cache_key = self._cacher.create_cache_key("exists", [doc])
+         if not self._cacher.exists(cache_key):
+             self._logger.debug(f"Inserting doc: {doc_prepared}")
+             self._db.insert(doc_prepared)
+             self._handle_index(doc_prepared)
+             self._cacher.set(cache_key, doc_prepared[self.ID_FIELD])
+         return doc_prepared[self.ID_FIELD]
+
+     def insert_batch(self, docs: list[dict]) -> list[str]:
+         docs_prepared = [self._prepare_doc_for_store(doc) for doc in docs]
+         docs_prepared = [doc for doc in docs_prepared if not self._cacher.exists(self._cacher.create_cache_key("exists", [doc]))]
+         self._logger.debug(f"Inserting doc batch: {docs_prepared}")
+         self._db.insert_multiple(docs_prepared)
+         for doc_prepared in docs_prepared:
+             self._handle_index(doc_prepared)
+         for doc in docs_prepared:
+             self._cacher.set(self._cacher.create_cache_key("exists", [doc]), doc[self.ID_FIELD])
+         return [doc[self.ID_FIELD] for doc in docs_prepared]
+
+     def update(self, doc: dict) -> str:
+         doc = self._prepare_doc_for_store(doc)
+         q = Query()
+         self._db.update(doc, q.id_ == doc[self.ID_FIELD])
+         self._handle_index(doc)
+         return doc[self.ID_FIELD]
+
+     def upsert(self, doc: dict) -> str:
+         doc = self._prepare_doc_for_store(doc)
+         q = Query()
+         self._db.upsert(doc, q.id_ == doc[self.ID_FIELD])
+         self._handle_index(doc)
+         return doc[self.ID_FIELD]
+
+     def retrieve(self, doc_id: str) -> dict | None:
+         docs = self.search(Filter(self.ID_FIELD, doc_id))
+         if len(docs) < 1:
+             return None
+         return docs[0]
+
+     def search(self, query_filter: Filter) -> list[dict]:
+         self._logger.debug(f"Searching with filter: {str(query_filter)}")
+         if query_filter.op == FilterOp.EQUAL and query_filter.value is not None:
+             if self._has_index(query_filter.key):
+                 return self._search_index(query_filter.key, query_filter.value)
+         tinydb_query = Query()
+         documents = self._db.search(tinydb_query[query_filter.key].test(self._filter_to_test_fn(query_filter)))
+         self._logger.debug(f"> documents found: {len(documents)}")
+         return [self._prepare_doc_for_return(dict(document)) for document in documents]
+
+     def _create_and_rebuild_index(self, field_name: str) -> None:
+         self._logger.debug(f"Creating index for: '{field_name}'...")
+         self._indexes[field_name] = defaultdict(list)
+         docs = self._db.all()
+         for doc in tqdm(docs, desc=f"recreating index '{field_name}'"):
+             self._add_to_index(field_name, doc)
+         self._logger.debug(f"> indexed {len(docs)} documents")
+
+     def _add_to_index(self, field_name: str, doc: dict) -> None:
+         if field_name not in doc:
+             return
+         self._indexes[field_name][doc[field_name]].append(doc)
+
+     def _remove_from_index(self, field_name: str, doc: dict) -> None:
+         if field_name not in doc:
+             return
+         # iterate over a copy, since matching entries are removed from the list
+         for indexed_doc in list(self._indexes[field_name][doc[field_name]] or []):
+             if indexed_doc[self.ID_FIELD] == doc[self.ID_FIELD]:
+                 self._indexes[field_name][doc[field_name]].remove(indexed_doc)
+
+     def _handle_index(self, doc: dict, remove: bool = False) -> None:
+         if remove:
+             self._handle_index_remove(doc)
+         else:
+             self._handle_index_add(doc)
+
+     def _handle_index_add(self, doc: dict) -> None:
+         for field_name in doc.keys():
+             if field_name in self._indexes:
+                 self._add_to_index(field_name, doc)
+
+     def _handle_index_remove(self, doc: dict):
+         self._logger.debug(f"_handle_index_remove | doc: {doc}")
+         if len(doc.keys()) == 1 and list(doc.keys())[0] == self.ID_FIELD:
+             self._logger.debug("> the doc has a single key, trying to get the full document")
+             doc = self.retrieve(doc[self.ID_FIELD])
+             if doc is None:
+                 return
+             self._logger.debug(f"> the full document: {doc}")
+         for doc_field in doc.keys():
+             if doc_field in self._indexes:
+                 self._remove_from_index(doc_field, doc)
+
+     def _has_index(self, key: str) -> bool:
+         if key not in self._indexes:
+             self._logger.debug(f"> there is no index: '{key}'")
+             return False
+         return True
+
+     def _search_index(self, key: str, value: Any) -> list[dict]:
+         if value not in self._indexes[key]:
+             return []
+         return self._indexes[key][value]
+
+     def _filter_to_test_fn(self, query_filter: Filter):
+         def test_func(val):
+             if not isinstance(val, (str, int, float)):
+                 return False
+             if query_filter.op == FilterOp.GREATER:
+                 return val > query_filter.value
+             if query_filter.op == FilterOp.GREATER_OR_EQUAL:
+                 return val >= query_filter.value
+             if query_filter.op == FilterOp.LESS:
+                 return val < query_filter.value
+             if query_filter.op == FilterOp.LESS_OR_EQUAL:
+                 return val <= query_filter.value
+             if query_filter.op == FilterOp.EQUAL:
+                 return val == query_filter.value
+             raise ValueError(f"Unknown operator {query_filter.op}")
+         return test_func
+
+     def remove(self, ids: list[str] | str) -> int:
+         if isinstance(ids, str):
+             ids = [ids]
+         removed_elements_ids = []
+         for doc_id in ids:
+             self._handle_index({self.ID_FIELD: doc_id}, remove=True)
+             q = Query()
+             removed_elements = self._db.remove(q.id_ == doc_id)
+             removed_elements_ids.extend(removed_elements)
+         return len(set(removed_elements_ids))
+
+     def _prepare_doc_for_store(self, doc: dict) -> dict:
+         self._ensure_created_at(doc)
+         return self._ensure_hash(self._ensure_id(doc))
+
+     def _prepare_doc_for_return(self, doc: dict) -> dict:
+         doc_copy = doc.copy()
+         self._ensure_created_at(doc_copy)
+         return doc_copy
+
+     def _ensure_created_at(self, doc: dict):
+         if "created_at" not in doc:
+             err_msg = f"There is no created_at in doc!\n{pp(doc, print_msg=False)}"
+             self._logger.error(err_msg)
+             raise ValueError(err_msg)
+
+     def _ensure_id(self, doc: dict) -> dict:
+         if self.ID_FIELD not in doc:
+             doc[self.ID_FIELD] = self._generate_id()
+         return doc
+
+     def _generate_id(self) -> str:
+         return str(uuid.uuid4())
+
+     def _ensure_hash(self, doc: dict) -> dict:
+         if self.HASH_FIELD in doc and doc[self.HASH_FIELD]:
+             self._logger.debug(f"the doc already has the hash property, keeping it as is: {doc[self.HASH_FIELD]}")
+             return doc
+         doc[self.HASH_FIELD] = self._generate_hash(doc)
+         return doc
+
+     def _generate_hash(self, doc: dict) -> str:
+         doc_copy = doc.copy()
+         if self.ID_FIELD in doc_copy:
+             del doc_copy[self.ID_FIELD]
+         if self.HASH_FIELD in doc_copy:
+             del doc_copy[self.HASH_FIELD]
+         return hash_dict(doc_copy)
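
A short usage sketch of LightTinyDB based only on the code above: every stored document must already contain a created_at field (enforced by _ensure_created_at), while id_ and hash_ are filled in automatically; EQUAL filters on an indexed field are served from the in-memory index, everything else falls back to a TinyDB scan with a test function. The file name and field names are illustrative:

# Hypothetical usage of LightTinyDB; "users.json" and the field names are examples.
from toolchemy.db.lightdb import LightTinyDB, Filter, FilterOp
from toolchemy.utils.datestimes import current_datetime_str

db = LightTinyDB("users.json", indexes=["email"])  # EQUAL lookups on "email" hit the index

doc_id = db.insert({"email": "ada@example.com", "age": 36, "created_at": current_datetime_str()})

by_email = db.search(Filter("email", "ada@example.com"))             # indexed path
adults = db.search(Filter("age", 18, op=FilterOp.GREATER_OR_EQUAL))  # linear scan with test fn

db.remove(doc_id)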
toolchemy/utils/__init__.py
@@ -0,0 +1,19 @@
+ from .cacher import ICacher, CacherPickle, Cacher, CacheEntryDoesNotExistError, DummyCacher
+ from .locations import Locations
+ from .logger import get_logger
+ from .utils import pp, pp_cast, ff, hash_dict, to_json, truncate
+ from .timer import Timer
+ from .at_exit_collector import ICollectable, AtExitCollector
+
+
+ __all__ = [
+     "ICollectable", "AtExitCollector",
+     "ICacher",
+     "CacherPickle",
+     "Cacher",
+     "CacheEntryDoesNotExistError",
+     "DummyCacher",
+     "get_logger",
+     "Locations",
+     "pp", "pp_cast", "ff", "hash_dict", "to_json", "truncate",
+     "Timer"]