toolchemy 0.2.185__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- toolchemy/__main__.py +9 -0
- toolchemy/ai/clients/__init__.py +20 -0
- toolchemy/ai/clients/common.py +429 -0
- toolchemy/ai/clients/dummy_model_client.py +61 -0
- toolchemy/ai/clients/factory.py +37 -0
- toolchemy/ai/clients/gemini_client.py +48 -0
- toolchemy/ai/clients/ollama_client.py +58 -0
- toolchemy/ai/clients/openai_client.py +76 -0
- toolchemy/ai/clients/pricing.py +66 -0
- toolchemy/ai/clients/whisper_client.py +141 -0
- toolchemy/ai/prompter.py +124 -0
- toolchemy/ai/trackers/__init__.py +5 -0
- toolchemy/ai/trackers/common.py +216 -0
- toolchemy/ai/trackers/mlflow_tracker.py +221 -0
- toolchemy/ai/trackers/neptune_tracker.py +135 -0
- toolchemy/db/lightdb.py +260 -0
- toolchemy/utils/__init__.py +19 -0
- toolchemy/utils/at_exit_collector.py +109 -0
- toolchemy/utils/cacher/__init__.py +20 -0
- toolchemy/utils/cacher/cacher_diskcache.py +121 -0
- toolchemy/utils/cacher/cacher_pickle.py +152 -0
- toolchemy/utils/cacher/cacher_shelve.py +196 -0
- toolchemy/utils/cacher/common.py +174 -0
- toolchemy/utils/datestimes.py +77 -0
- toolchemy/utils/locations.py +111 -0
- toolchemy/utils/logger.py +76 -0
- toolchemy/utils/timer.py +23 -0
- toolchemy/utils/utils.py +168 -0
- toolchemy/vision/__init__.py +5 -0
- toolchemy/vision/caption_overlay.py +77 -0
- toolchemy/vision/image.py +89 -0
- toolchemy-0.2.185.dist-info/METADATA +25 -0
- toolchemy-0.2.185.dist-info/RECORD +36 -0
- toolchemy-0.2.185.dist-info/WHEEL +4 -0
- toolchemy-0.2.185.dist-info/entry_points.txt +3 -0
- toolchemy-0.2.185.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import hashlib
|
|
3
|
+
import copy
|
|
4
|
+
from abc import abstractmethod
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from toolchemy.utils.at_exit_collector import ICollectable, AtExitCollector
|
|
8
|
+
from toolchemy.utils.datestimes import current_date_str, current_unix_timestamp
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class CacherInitializationError(Exception):
    """Raised when a cacher backend cannot be initialized."""
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class CacheEntryDoesNotExistError(Exception):
    """Raised when a requested cache entry is absent."""
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class CacheEntryHasNotBeenSetError(Exception):
    """Raised when a cache entry is read before ever being set."""
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ICacher(abc.ABC):
|
|
24
|
+
CACHER_MAIN_NAME = ".cache"
|
|
25
|
+
|
|
26
|
+
"""
|
|
27
|
+
Cacher interface
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
@abc.abstractmethod
|
|
31
|
+
def sub_cacher(self, log_level: int | None = None, suffix: str | None = None) -> "ICacher":
|
|
32
|
+
pass
|
|
33
|
+
|
|
34
|
+
@abc.abstractmethod
|
|
35
|
+
def exists(self, name: str) -> bool:
|
|
36
|
+
"""
|
|
37
|
+
Checks if there is a cache entry for a given name
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
@abc.abstractmethod
|
|
41
|
+
def set(self, name: str, content: Any, ttl_s: int | None = None):
|
|
42
|
+
"""
|
|
43
|
+
Dumps a given object under a given cache entry name. The object must be pickleable.
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
@abc.abstractmethod
|
|
47
|
+
def unset(self, name: str):
|
|
48
|
+
"""
|
|
49
|
+
Removes a cache entry for a given name
|
|
50
|
+
"""
|
|
51
|
+
|
|
52
|
+
@abc.abstractmethod
|
|
53
|
+
def get(self, name: str) -> Any:
|
|
54
|
+
"""
|
|
55
|
+
Loads an object for a given cache entry name. If it doesn't exist, an exception is thrown.
|
|
56
|
+
"""
|
|
57
|
+
|
|
58
|
+
@property
|
|
59
|
+
@abstractmethod
|
|
60
|
+
def cache_location(self) -> str:
|
|
61
|
+
pass
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class BaseCacher(ICacher, ICollectable, abc.ABC):
    """
    Shared base for cachers: hit/miss statistics (reported at interpreter
    exit via AtExitCollector), cache-key construction helpers and TTL
    enveloping of stored payloads.
    """

    def __init__(self):
        self._name = self.__module__
        # Hit/miss counters, returned by collect() for exit-time reporting.
        self._cache_stats = {
            "hit": 0,
            "miss": 0,
        }
        AtExitCollector.register(self)

    def collect(self) -> dict:
        """Returns the accumulated hit/miss statistics."""
        return self._cache_stats

    def label(self) -> str:
        """Human-readable identifier used when reporting statistics."""
        return f"{self.__class__.__name__}({self._name})"

    def exists(self, name: str) -> bool:
        """Checks entry existence, updating the hit/miss counters."""
        does_exist = self._exists(name)
        if does_exist:
            self._cache_stats["hit"] += 1
        else:
            self._cache_stats["miss"] += 1
        return does_exist

    @abc.abstractmethod
    def _exists(self, name: str) -> bool:
        """
        Checks if there is a cache entry for a given name
        """

    def persist(self):
        """Flushes in-memory state to the backing store (no-op by default)."""
        pass

    @staticmethod
    def hash(name: str) -> str:
        """MD5 hex digest of *name* — used for key shortening, not security."""
        hash_object = hashlib.md5(name.encode('utf-8'))
        return hash_object.hexdigest()

    @staticmethod
    def create_cache_key(parts_plain: list | str | None = None, parts_hashed: list | str | None = None,
                         with_current_date: bool = False) -> str:
        """
        Builds a cache key: plain parts are sanitized (special chars -> '_'),
        hashed parts are MD5-digested, optionally suffixed with YYYYMMDD.

        Raises:
            ValueError: when neither plain nor hashed parts are given.

        BUGFIX: no longer mutates the caller's parts_plain list in place.
        """
        replaceable_chars = "*.,'\"|<>[]?!-:;()@#$%^&{} "
        if parts_plain is None and parts_hashed is None:
            raise ValueError("You must provide the key components")
        if parts_plain is None:
            parts_plain = []
        if parts_hashed is None:
            parts_hashed = []
        if isinstance(parts_plain, str):
            parts_plain = [parts_plain]
        if isinstance(parts_hashed, str):
            parts_hashed = [parts_hashed]

        # Sanitize into a new list so the caller's list is left untouched.
        sanitized_plain = [
            "".join("_" if ch in replaceable_chars else ch for ch in str(part))
            for part in parts_plain
        ]
        hashed = [BaseCacher.hash(str(part_hashed)) for part_hashed in parts_hashed]
        parts = sanitized_plain + hashed
        if with_current_date:
            parts.append(current_date_str("%Y%m%d"))

        return "_".join(parts)

    def _envelop(self, content: Any, ttl_s: int | None = None) -> dict[str, Any]:
        """
        Wraps *content* into a {'data', 'timestamp', 'ttl_s'} envelope unless
        it already is one.

        BUGFIX: the old check ('and'-chained "not in" tests) treated any dict
        containing at least ONE envelope key as already enveloped, so a plain
        payload dict that happened to have e.g. a 'data' key was stored
        without timestamp/ttl. Now only a dict with ALL three keys passes
        through unwrapped.
        """
        already_enveloped = isinstance(content, dict) and all(
            key in content for key in ("data", "timestamp", "ttl_s")
        )
        if not already_enveloped:
            entry_timestamp = current_unix_timestamp()
            content = {'data': content, 'timestamp': entry_timestamp, 'ttl_s': ttl_s}
        return content
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class DummyLock:
    """No-op context manager, usable where a real lock is unnecessary."""

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning None never suppresses exceptions.
        return None
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
class DummyCacher(BaseCacher):
    """
    Cacher that by default caches nothing; with with_memory_store=True it
    keeps entries in an in-process dict (deep-copied on write).
    """

    def __init__(self, with_memory_store: bool = False):
        super().__init__()
        self._data = {}
        self._with_memory_store = with_memory_store

    @property
    def cache_location(self) -> str:
        # Not backed by the filesystem.
        return ""

    def sub_cacher(self, log_level: int | None = None, suffix: str | None = None) -> "ICacher":
        # The child shares only the configuration flag, not the stored data.
        return DummyCacher(with_memory_store=self._with_memory_store)

    def _exists(self, name: str) -> bool:
        return self._with_memory_store and name in self._data

    def set(self, name: str, content: Any, ttl_s: int | None = None):
        if self._with_memory_store:
            # Deep copy so later caller-side mutation cannot leak in.
            self._data[name] = copy.deepcopy(content)

    def unset(self, name: str):
        self._data.pop(name, None)

    def get(self, name: str) -> Any:
        if not self._with_memory_store:
            return None
        if name not in self._data:
            raise CacheEntryDoesNotExistError()
        return self._data[name]
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
import datetime
|
|
2
|
+
import time
|
|
3
|
+
from enum import Enum
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Seconds(float, Enum):
    """Common time spans expressed in seconds."""

    NANOSECOND = 1 / 1_000_000_000
    MICROSECOND = 1 / 1_000_000
    MILLISECOND = 1 / 1_000
    SECOND = 1
    MINUTE = 60
    HOUR = 3_600          # 60 * MINUTE
    DAY = 86_400          # 24 * HOUR
    WEEK = 604_800        # 7 * DAY
    MONTH = 2_592_000     # 30 * DAY
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Default formats shared by the date/time helpers in this module.
DEFAULT_DATE_FORMAT = "%d-%m-%Y"
DEFAULT_DATETIME_FORMAT = "%d-%m-%Y %H:%M:%S"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def format_str(datetime_str: str, date_format: str, target_datetime_format: str = DEFAULT_DATETIME_FORMAT) -> str:
    """Re-formats *datetime_str* from *date_format* into *target_datetime_format*."""
    parsed = datetime.datetime.strptime(datetime_str, date_format)
    return datetime_to_str(parsed, datetime_format=target_datetime_format)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def date_to_str(date_obj: datetime.date, date_format: str = DEFAULT_DATE_FORMAT) -> str:
    """Formats *date_obj* using *date_format*."""
    # format() delegates to date.__format__, i.e. strftime.
    return format(date_obj, date_format)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def datetime_to_str(datetime_obj: datetime.datetime, datetime_format: str = DEFAULT_DATETIME_FORMAT) -> str:
    """Formats *datetime_obj* using *datetime_format*."""
    # format() delegates to datetime.__format__, i.e. strftime.
    return format(datetime_obj, datetime_format)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def str_to_datetime(datetime_str: str, datetime_format: str = DEFAULT_DATETIME_FORMAT) -> datetime.datetime:
    """Parses *datetime_str* according to *datetime_format*."""
    parsed = datetime.datetime.strptime(datetime_str, datetime_format)
    return parsed
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def str_to_date(date_str: str, date_format: str = DEFAULT_DATE_FORMAT) -> datetime.date:
    """Parses *date_str* according to *date_format* and drops the time part."""
    parsed = datetime.datetime.strptime(date_str, date_format)
    return parsed.date()
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def current_date_str(date_format: str = DEFAULT_DATE_FORMAT, time_delta_days: int | None = None) -> str:
    """Today's date as a string; *time_delta_days* shifts that many days into the past."""
    today = datetime.date.today()
    if time_delta_days is not None:
        today = today - datetime.timedelta(days=time_delta_days)
    return date_to_str(today, date_format)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def current_datetime_str(datetime_format: str = DEFAULT_DATETIME_FORMAT, time_delta_days: int | None = None) -> str:
    """The current local datetime as a string; *time_delta_days* shifts into the past."""
    now = datetime.datetime.now()
    if time_delta_days is not None:
        now = now - datetime.timedelta(days=time_delta_days)
    return datetime_to_str(now, datetime_format)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def seconds_to_time_str(seconds: int | float) -> str:
|
|
58
|
+
seconds = int(seconds)
|
|
59
|
+
hours, remainder = divmod(seconds, 3600)
|
|
60
|
+
minutes, seconds = divmod(remainder, 60)
|
|
61
|
+
return f"{hours:02}:{minutes:02}:{seconds:02}"
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def datetime_to_unix_timestamp(datetime_ob: datetime.datetime) -> int:
    """Converts a datetime to an integer (truncated) Unix timestamp."""
    ts = datetime_ob.timestamp()
    return int(ts)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def unix_timestamp_to_datetime(unix_timestamp: int) -> datetime.datetime:
    """Converts a Unix timestamp to a naive local-time datetime."""
    converted = datetime.datetime.fromtimestamp(unix_timestamp)
    return converted
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def str_to_unix_timestamp(datetime_str: str, datetime_format: str = DEFAULT_DATETIME_FORMAT) -> int:
    """Parses *datetime_str* and converts it to an integer Unix timestamp."""
    parsed = str_to_datetime(datetime_str, datetime_format)
    return datetime_to_unix_timestamp(parsed)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def current_unix_timestamp() -> int:
    """Current time as whole seconds since the Unix epoch."""
    now = time.time()
    return int(now)
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import os
|
|
3
|
+
import json
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class PathDoesNotExistError(Exception):
    """Raised when a required filesystem path is missing."""
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _find_project_root(path: Path) -> Path:
|
|
12
|
+
for parent in [path] + list(path.parents):
|
|
13
|
+
if (parent / 'pyproject.toml').exists() or (parent / '.git').exists():
|
|
14
|
+
return parent
|
|
15
|
+
return path
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def get_external_caller_path(exclude_prefixes=None) -> str:
    """
    Walks the call stack and returns the project root of the first frame whose
    module lives outside *exclude_prefixes* (defaults to this file's directory).

    Falls back to the first exclude prefix when no external frame is found;
    raises RuntimeError only when exclude_prefixes is empty as well.
    """
    if exclude_prefixes is None:
        exclude_prefixes = [str(Path(__file__).resolve().parent)]

    for frame_info in inspect.stack():
        module = inspect.getmodule(frame_info.frame)
        # Frames without a module/__file__ (e.g. exec'd code) cannot be located.
        if not module or not hasattr(module, '__file__'):
            continue

        path = Path(module.__file__).resolve()
        if all(not str(path).startswith(prefix) for prefix in exclude_prefixes):
            # parents[1]: skip the module file and its immediate package dir
            # before searching for pyproject.toml/.git upward.
            project_root_path = str(_find_project_root(path.parents[1]))
            if "site-packages" in project_root_path:
                # Caller is an installed dependency -> its CWD is the best
                # guess for the project root.
                return str(Path.cwd())
            return project_root_path

    if exclude_prefixes and len(exclude_prefixes) > 0:
        return str(exclude_prefixes[0])
    raise RuntimeError("Could not find external caller")
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class Locations:
|
|
40
|
+
def __init__(self, prefix_dirs: dict | None = None, root_path: str | None = None, objective_path_mode: bool = False) -> None:
|
|
41
|
+
if root_path is None:
|
|
42
|
+
root_path = get_external_caller_path()
|
|
43
|
+
self._dirs = {
|
|
44
|
+
"root": root_path.rstrip("/"),
|
|
45
|
+
"resources": os.path.join(root_path, "resources").rstrip("/"),
|
|
46
|
+
"data": os.path.join(root_path, "data").rstrip("/"),
|
|
47
|
+
"logs": os.path.join(root_path, "logs").rstrip("/"),
|
|
48
|
+
}
|
|
49
|
+
self._prefix_dirs = prefix_dirs
|
|
50
|
+
self._objective_path_mode = objective_path_mode
|
|
51
|
+
|
|
52
|
+
@property
|
|
53
|
+
def root(self) -> str:
|
|
54
|
+
ret = self._dirs["root"]
|
|
55
|
+
if self._objective_path_mode:
|
|
56
|
+
ret = Path(ret)
|
|
57
|
+
return ret
|
|
58
|
+
|
|
59
|
+
@staticmethod
|
|
60
|
+
def read_content(path: str | Path) -> str:
|
|
61
|
+
if isinstance(path, str):
|
|
62
|
+
path = Path(path)
|
|
63
|
+
return path.read_text()
|
|
64
|
+
|
|
65
|
+
@staticmethod
|
|
66
|
+
def read_json(path: str | Path) -> dict | list:
|
|
67
|
+
content = Locations.read_content(path)
|
|
68
|
+
try:
|
|
69
|
+
json_data = json.loads(content)
|
|
70
|
+
except json.JSONDecodeError as e:
|
|
71
|
+
raise json.JSONDecodeError(f"{e} (Failed to parse JSON from {path}: {e}. Parsed content:\n{content})", e.doc, e.pos) from e
|
|
72
|
+
return json_data
|
|
73
|
+
|
|
74
|
+
@staticmethod
|
|
75
|
+
def save_json(data: dict | list, path: str | Path):
|
|
76
|
+
with open(str(path), "w") as f:
|
|
77
|
+
json.dump(data, f, indent=4)
|
|
78
|
+
|
|
79
|
+
@staticmethod
|
|
80
|
+
def abs(path: str | Path) -> str:
|
|
81
|
+
return os.path.abspath(os.path.expanduser(str(path)))
|
|
82
|
+
|
|
83
|
+
def project_rel(self, path: str | Path) -> str | Path:
|
|
84
|
+
abs_path = str(self.abs(path))
|
|
85
|
+
rel_path = abs_path.replace(self.root, ".")
|
|
86
|
+
if self._objective_path_mode:
|
|
87
|
+
rel_path = Path(rel_path)
|
|
88
|
+
return rel_path
|
|
89
|
+
|
|
90
|
+
def in_root(self, elements: str | list[str] | None = None) -> str | Path:
|
|
91
|
+
return self.in_("root", elements)
|
|
92
|
+
|
|
93
|
+
def in_resources(self, elements: str | list[str] | None = None) -> str | Path:
|
|
94
|
+
return self.in_("resources", elements)
|
|
95
|
+
|
|
96
|
+
def in_data(self, elements: str | list[str] | None = None) -> str | Path:
|
|
97
|
+
return self.in_("data", elements)
|
|
98
|
+
|
|
99
|
+
def in_(self, base_dir: str, elements: str | list[str] | None = None) -> str | Path:
|
|
100
|
+
if isinstance(elements, str):
|
|
101
|
+
elements = [elements]
|
|
102
|
+
if elements is None:
|
|
103
|
+
elements = []
|
|
104
|
+
if base_dir not in self._dirs:
|
|
105
|
+
raise ValueError(f"There is no '{base_dir}' dir defined")
|
|
106
|
+
if self._prefix_dirs and base_dir in self._prefix_dirs:
|
|
107
|
+
elements = [self._prefix_dirs[base_dir]] + elements
|
|
108
|
+
ret = str(os.path.join(self._dirs[base_dir], *elements)).rstrip("/")
|
|
109
|
+
if self._objective_path_mode:
|
|
110
|
+
ret = Path(ret)
|
|
111
|
+
return ret
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import pathlib
|
|
3
|
+
from colorlog import ColoredFormatter
|
|
4
|
+
|
|
5
|
+
from toolchemy.utils.utils import _caller_module_name
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def get_logger(name: str | None = None, level: int = logging.INFO, log_dir: str | None = None, with_time: bool = True,
               with_module_name: bool = True, with_log_level: bool = True, short_module_name: bool = False, say_hi: bool = False) -> logging.Logger:
    """
    Returns a colorized logger configured with a console handler and,
    optionally, a file handler under *log_dir*.

    The logger's existing handlers are replaced, so repeated calls do not
    duplicate output. *name* defaults to the calling module's name.

    BUGFIX: handlers were removed while iterating the live logger.handlers
    list; removeHandler() mutates that list, so every other handler was
    skipped. Iterating a copy removes them all.
    """
    name = name or _caller_module_name()

    datetime_format = "%H:%M:%S"
    prompts_parts = []
    if with_time:
        prompts_parts.append('%(asctime)s')
    if with_module_name:
        module_format = '%(module)s' if short_module_name else '%(name)s'
        prompts_parts.append(module_format)
    if with_log_level:
        prompts_parts.append('%(levelname)s')

    msg_format = "%(log_color)s|" + " ".join(prompts_parts) + "|%(reset)s %(message)s"

    formatter = ColoredFormatter(fmt=msg_format, datefmt=datetime_format, force_color=True, log_colors={
        'DEBUG': 'cyan',
        'INFO': 'green',
        'WARNING': 'yellow',
        'ERROR': 'red',
        'CRITICAL': 'red,bg_white',
    },
                                 secondary_log_colors={})

    logger = logging.getLogger(name)
    # Iterate over a COPY: removeHandler() mutates logger.handlers.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)

    logger.setLevel(level)

    # All previous handlers are gone, so attach the console handler directly.
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    if log_dir is not None:
        log_dir_path = pathlib.Path(log_dir)
        log_dir_path.mkdir(parents=True, exist_ok=True)

        log_file = log_dir_path / f"{name.replace('.', '_')}.log"
        file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    # Avoid double-logging through ancestor loggers.
    logger.propagate = False

    if say_hi:
        logger.info("Hi:)")
        logger.debug("Debug mode ON")
        logger.debug("All handlers:")
        for handler in logger.handlers:
            # Both handlers above were given a formatter, so ._fmt is safe here.
            logger.debug(f"- name: {handler.get_name()}, format: {handler.formatter._fmt}, level: {handler.level}")

    return logger
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def testing():
    """Manual smoke test: emits one message per level through the configured logger."""
    logger = get_logger("toolchemy.utils.logger", level=logging.DEBUG, say_hi=True)
    logger.warning("Testing logger setup WARNING")
    logger.info("Testing logger setup INFO")
    logger.debug("Testing logger setup DEBUG")
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# Allows running this module directly for a quick visual logger check.
if __name__ == "__main__":
    testing()
|
toolchemy/utils/timer.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import time
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class Timer:
    """Simple wall-clock stopwatch built on time.time()."""

    def __init__(self, started: bool = True):
        # NOTE(review): 'started' is stored but never consulted by reset()/tap();
        # the timer always starts running immediately.
        self._started = started
        self._started_time = None
        self._last_tap_time = None
        self.reset()

    def reset(self):
        """Restarts both the total clock and the since-last-tap clock."""
        self._started_time = time.time()
        self._last_tap_time = time.time()

    def tap(self, since_last: bool = False) -> float:
        """Seconds elapsed since start (or since the previous tap when since_last=True)."""
        now = time.time()
        reference = self._last_tap_time if (since_last and self._last_tap_time) else self._started_time
        self._last_tap_time = now
        return now - reference
|
toolchemy/utils/utils.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import datetime
|
|
3
|
+
import inspect
|
|
4
|
+
import json
|
|
5
|
+
import numpy as np
|
|
6
|
+
import random
|
|
7
|
+
import torch
|
|
8
|
+
import hashlib
|
|
9
|
+
import base64
|
|
10
|
+
|
|
11
|
+
from dataclasses import asdict, is_dataclass
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
from toolchemy.utils.datestimes import datetime_to_str
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
DEFAULT_SEED = 1337

# Byte-size constants.
MEGABYTE = 1024 ** 2
GIGABYTE = 1024 ** 3


def seed_init_fn(x, only_deterministic: bool = False):
    """
    Seeds numpy, random and torch with DEFAULT_SEED + x (e.g. per dataloader
    worker); optionally forces torch deterministic algorithms.
    """
    seed = DEFAULT_SEED + x
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)
    if only_deterministic:
        torch.use_deterministic_algorithms(True)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def bytes_to_str(byte_data: bytes) -> str:
    """
    Decodes *byte_data* trying utf-8, then latin-1, then ascii.

    NOTE(review): latin-1 can decode any byte sequence, so the ascii fallback
    and the ValueError are effectively unreachable — kept for parity.
    """
    for codec in ('utf-8', 'latin-1', 'ascii'):
        try:
            return byte_data.decode(codec)
        except UnicodeDecodeError:
            pass
    raise ValueError("Unknown encoding")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def pp_cast(msg: Any, skip_fields: list | None = None) -> Any:
    """
    Recursively converts *msg* into JSON-printable builtins: dataclasses via
    to_json, bytes decoded, dict/list recursed (dropping *skip_fields* keys),
    numpy arrays to lists, floats formatted via ff, datetimes stringified,
    and any remaining non-builtin object serialized through its __dict__.
    Works on a deep copy; *msg* is never mutated.
    """
    msg_copy = copy.deepcopy(msg)
    if is_dataclass(msg_copy):
        msg_copy = to_json(msg_copy)
    if isinstance(msg_copy, bytes):
        msg_copy = bytes_to_str(msg_copy)
    if isinstance(msg_copy, dict):
        # Iterate over a snapshot of the keys since entries may be deleted.
        for key in list(msg_copy):
            if skip_fields and key in skip_fields:
                del msg_copy[key]
                continue
            msg_copy[key] = pp_cast(msg_copy[key], skip_fields=skip_fields)
    if isinstance(msg_copy, list):
        for i, el in enumerate(msg_copy):
            msg_copy[i] = pp_cast(el, skip_fields=skip_fields)
    if isinstance(msg_copy, np.ndarray):
        # Recurse so the converted list gets per-element treatment too.
        msg_copy = pp_cast(msg_copy.tolist(), skip_fields=skip_fields)
    if isinstance(msg_copy, float):
        # NOTE(review): formats the original *msg*, not msg_copy — equal in
        # value for floats since deepcopy preserves it, but worth confirming
        # this asymmetry is intentional.
        msg_copy = ff(msg)
    if isinstance(msg_copy, datetime.datetime):
        msg_copy = datetime_to_str(msg_copy)
    # isinstance(..., object) is always True; the module check alone decides.
    if isinstance(msg_copy, object) and type(msg_copy).__module__ != "builtins":
        msg_copy = json.loads(json.dumps(msg_copy, default=vars))
    return msg_copy
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def pp(msg: str | dict | float | int | list, skip_fields: list | None = None, print_msg: bool = True) -> str:
    """
    Pretty-prints *msg* (dicts and lists of dicts rendered as indented JSON,
    numbers via ff) and returns the rendered text.
    """
    rendered = pp_cast(msg, skip_fields=skip_fields)
    if isinstance(rendered, dict):
        rendered = json.dumps(rendered, indent=4, ensure_ascii=False)
    elif isinstance(rendered, list) and rendered and isinstance(rendered[0], dict):
        rendered = json.dumps(rendered, indent=4, ensure_ascii=False)
    if isinstance(rendered, (int, float)):
        rendered = ff(rendered)
    if print_msg:
        print(rendered)
    return rendered
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def ff(fval: float | list[float] | int | list[int] | dict | str | np.float32, precision: int = 2):
|
|
83
|
+
if isinstance(fval, float):
|
|
84
|
+
return "%0.*f" % (precision, fval)
|
|
85
|
+
if isinstance(fval, int):
|
|
86
|
+
return str(fval)
|
|
87
|
+
if isinstance(fval, list):
|
|
88
|
+
return [ff(v, precision=precision) for v in fval]
|
|
89
|
+
if isinstance(fval, dict):
|
|
90
|
+
return {k: ff(v, precision=precision) for k, v in fval.items()}
|
|
91
|
+
if isinstance(fval, str):
|
|
92
|
+
return ff(float(fval), precision)
|
|
93
|
+
if isinstance(fval, np.float32):
|
|
94
|
+
return ff(fval.item(), precision)
|
|
95
|
+
raise ValueError(f"Unsupported type: {type(fval)}")
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def to_json(dataclass_container, key_prefix: str = None, exclude: list[str] | None = None) -> dict:
|
|
99
|
+
if exclude is None:
|
|
100
|
+
exclude = []
|
|
101
|
+
data_dict = asdict(dataclass_container)
|
|
102
|
+
|
|
103
|
+
parsed_dict = {}
|
|
104
|
+
for k, v in data_dict.items():
|
|
105
|
+
if k in exclude:
|
|
106
|
+
continue
|
|
107
|
+
new_key = k
|
|
108
|
+
if key_prefix is not None:
|
|
109
|
+
new_key = f"{key_prefix}_{k}"
|
|
110
|
+
if v is None:
|
|
111
|
+
v = ""
|
|
112
|
+
if isinstance(v, bool):
|
|
113
|
+
v = int(v)
|
|
114
|
+
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], str):
|
|
115
|
+
v = ', '.join(v)
|
|
116
|
+
if is_dataclass(v):
|
|
117
|
+
v = to_json(v, key_prefix=key_prefix, exclude=exclude)
|
|
118
|
+
parsed_dict[new_key] = v
|
|
119
|
+
|
|
120
|
+
return parsed_dict
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def hash_dict(input_dict: dict) -> str:
    """Deterministic base64-encoded SHA-256 of the dict's sorted-key JSON form."""
    canonical = json.dumps(input_dict, sort_keys=True).encode('utf-8')
    digest = hashlib.sha256(canonical).digest()
    return base64.b64encode(digest).decode('utf-8')
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def normalize_path_str(path_: str) -> str:
    """Flattens a path/URL string into an identifier-ish token (separators -> underscores)."""
    # Order matters: multi-char patterns must be handled before single chars.
    replacements = (
        ("./", "_"),
        ("~/", ""),
        ("~", ""),
        ("/", "_"),
        ("-", ""),
        (":", "_"),
        ("?", "_"),
        ("&", "_"),
    )
    result = path_
    for old, new in replacements:
        result = result.replace(old, new)
    return result
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def split_text(text: str, chunk_size: int, chunk_overlap: int) -> list[str]:
    """
    Splits *text* into chunks of *chunk_size* characters where consecutive
    chunks share *chunk_overlap* characters; the final chunk holds any
    remainder.

    Raises:
        ValueError: when chunk_size <= 0 or chunk_overlap is not in
            [0, chunk_size) — the original code divided by
            (chunk_size - chunk_overlap) and crashed with ZeroDivisionError
            when they were equal.

    BUGFIX: also no longer appends an empty trailing chunk when the text
    length divides evenly, and clamps the chunk count for very short texts
    (previously a negative count produced a bogus tail slice).
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if chunk_overlap < 0 or chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")

    step = chunk_size - chunk_overlap
    num_chunks = max((len(text) - chunk_overlap) // step, 0)

    chunks = [text[i * step:i * step + chunk_size] for i in range(num_chunks)]

    # Append the remainder; skip it when empty (unless there are no chunks
    # at all, so empty input still yields [""] like the original).
    tail = text[num_chunks * step:]
    if tail or not chunks:
        chunks.append(tail)

    return chunks
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def truncate(s: str, limit: int) -> str:
    """Returns *s* unchanged when within *limit*, else the prefix plus an omitted-char count."""
    overflow = len(s) - limit
    if overflow <= 0:
        return s
    return f"{s[:limit]} (...{overflow} more chars)"
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _caller_module_name(offset: int = 2) -> str:
|
|
165
|
+
frame = inspect.stack()[offset]
|
|
166
|
+
module = inspect.getmodule(frame.frame)
|
|
167
|
+
namespace = module.__name__ if module and hasattr(module, '__name__') else '__main__'
|
|
168
|
+
return namespace
|