toolchemy 0.2.185__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- toolchemy/__main__.py +9 -0
- toolchemy/ai/clients/__init__.py +20 -0
- toolchemy/ai/clients/common.py +429 -0
- toolchemy/ai/clients/dummy_model_client.py +61 -0
- toolchemy/ai/clients/factory.py +37 -0
- toolchemy/ai/clients/gemini_client.py +48 -0
- toolchemy/ai/clients/ollama_client.py +58 -0
- toolchemy/ai/clients/openai_client.py +76 -0
- toolchemy/ai/clients/pricing.py +66 -0
- toolchemy/ai/clients/whisper_client.py +141 -0
- toolchemy/ai/prompter.py +124 -0
- toolchemy/ai/trackers/__init__.py +5 -0
- toolchemy/ai/trackers/common.py +216 -0
- toolchemy/ai/trackers/mlflow_tracker.py +221 -0
- toolchemy/ai/trackers/neptune_tracker.py +135 -0
- toolchemy/db/lightdb.py +260 -0
- toolchemy/utils/__init__.py +19 -0
- toolchemy/utils/at_exit_collector.py +109 -0
- toolchemy/utils/cacher/__init__.py +20 -0
- toolchemy/utils/cacher/cacher_diskcache.py +121 -0
- toolchemy/utils/cacher/cacher_pickle.py +152 -0
- toolchemy/utils/cacher/cacher_shelve.py +196 -0
- toolchemy/utils/cacher/common.py +174 -0
- toolchemy/utils/datestimes.py +77 -0
- toolchemy/utils/locations.py +111 -0
- toolchemy/utils/logger.py +76 -0
- toolchemy/utils/timer.py +23 -0
- toolchemy/utils/utils.py +168 -0
- toolchemy/vision/__init__.py +5 -0
- toolchemy/vision/caption_overlay.py +77 -0
- toolchemy/vision/image.py +89 -0
- toolchemy-0.2.185.dist-info/METADATA +25 -0
- toolchemy-0.2.185.dist-info/RECORD +36 -0
- toolchemy-0.2.185.dist-info/WHEEL +4 -0
- toolchemy-0.2.185.dist-info/entry_points.txt +3 -0
- toolchemy-0.2.185.dist-info/licenses/LICENSE +21 -0
toolchemy/__main__.py
ADDED

toolchemy/ai/clients/__init__.py
ADDED
@@ -0,0 +1,20 @@
from .common import ILLMClient, LLMClientBase, ChatMessage, ChatMessages, ModelConfig, ModelResponseError, LLMCacheDoesNotExist, prepare_chat_messages
from .ollama_client import OllamaClient
from .openai_client import OpenAIClient
from .gemini_client import GeminiClient
from .dummy_model_client import DummyModelClient
from .factory import create_llm


__all__ = [
    "ILLMClient", "LLMClientBase",
    "OllamaClient",
    "OpenAIClient",
    "GeminiClient",
    "DummyModelClient",
    "create_llm",
    "ModelConfig",
    "ModelResponseError",
    "LLMCacheDoesNotExist",
    "prepare_chat_messages",
    "ChatMessage", "ChatMessages"]
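Because of these re-exports, the client surface can be imported from the package itself rather than from the individual submodules. A minimal illustrative sketch (the values below are placeholders, not package defaults):

from toolchemy.ai.clients import ModelConfig, prepare_chat_messages

# Build a config and a chat-message list using only the re-exported names.
config = ModelConfig(model_name="any-model", temperature=0.2, max_new_tokens=512)
messages = prepare_chat_messages("Hello!", system_prompt="Be brief.")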
toolchemy/ai/clients/common.py
ADDED
@@ -0,0 +1,429 @@
import logging
import json

from jsonschema import validate, ValidationError
from tenacity import wait_exponential, Retrying, before_sleep_log, stop_after_attempt
from dataclasses import dataclass
from json import JSONDecodeError
from json.decoder import JSONDecodeError as JSONDecoderDecodeError
from abc import ABC, abstractmethod
from pydantic import BaseModel
from typing import TypedDict, NotRequired

from toolchemy.utils.logger import get_logger
from toolchemy.utils.utils import ff, truncate
from toolchemy.utils.cacher import Cacher, DummyCacher, ICacher
from toolchemy.utils.at_exit_collector import ICollectable, AtExitCollector
from toolchemy.ai.clients.pricing import Pricing


class LLMCacheDoesNotExist(Exception):
    pass


class ModelResponseError(Exception):
    pass


class ModelConfig(BaseModel):
    model_name: str | None = None

    max_new_tokens: int | None = 2000
    presence_penalty: float | None = 0.0

    temperature: float = 0.7
    top_p: float = 1.0

    def __repr__(self):
        return str(self)

    def __str__(self):
        return f"{self.model_name}__{self.max_new_tokens}__{ff(self.presence_penalty)}__{ff(self.temperature)}__{ff(self.top_p)}"

    @classmethod
    def from_raw(cls, data: dict) -> "ModelConfig":
        return cls(model_name=data["model_name"], max_new_tokens=int(data["max_new_tokens"]),
                   presence_penalty=data["presence_penalty"], temperature=data["temperature"], top_p=data["top_p"])

    def raw(self) -> dict:
        return {
            "model_name": self.model_name,
            "max_new_tokens": self.max_new_tokens,
            "presence_penalty": self.presence_penalty,
            "temperature": self.temperature,
            "top_p": self.top_p,
        }


@dataclass
class Usage:
    input_tokens: int
    output_tokens: int
    duration: float
    cached: bool = False

    def __eq__(self, other: "Usage"):
        return other.input_tokens == self.input_tokens and other.output_tokens == self.output_tokens and other.duration == self.duration


class ILLMClient(ABC):
    @abstractmethod
    def name(self) -> str:
        pass

    @property
    @abstractmethod
    def metadata(self) -> dict:
        pass

    @property
    @abstractmethod
    def system_prompt(self) -> str:
        pass

    @abstractmethod
    def model_config(self, base_config: ModelConfig | None = None, default_model_name: str | None = None) -> ModelConfig:
        pass

    @abstractmethod
    def completion(self, prompt: str, model_config: ModelConfig | None = None, images_base64: list[str] | None = None,
                   no_cache: bool = False, cache_only: bool = False) -> str:
        pass

    @abstractmethod
    def completion_json(self, prompt: str, model_config: ModelConfig | None = None, images_base64: list[str] | None = None, validation_schema: dict | None = None,
                        no_cache: bool = False, cache_only: bool = False) -> dict | list[dict]:
        pass

    @abstractmethod
    def invalidate_completion_cache(self, prompt: str, model_config: ModelConfig | None = None, images_base64: list[str] | None = None):
        pass

    @abstractmethod
    def embeddings(self, text: str) -> list[float]:
        pass

    @property
    @abstractmethod
    def usage_summary(self) -> dict:
        pass

    @abstractmethod
    def usage(self, tail: int | None = None) -> list[Usage]:
        pass


class ChatMessageContent(TypedDict):
    type: str
    text: NotRequired[str]
    image_url: NotRequired[str]


class ChatMessage(TypedDict):
    role: str
    content: str | list[ChatMessageContent]


class ChatMessages(TypedDict):
    messages: list[ChatMessage]


class LLMClientBase(ILLMClient, ICollectable, ABC):
    def __init__(self, default_model_name: str | None = None, default_embedding_model_name: str | None = None,
                 default_model_config: ModelConfig | None = None,
                 system_prompt: str | None = None, keep_chat_session: bool = False,
                 retry_attempts: int = 5, retry_min_wait: int = 2, retry_max_wait: int = 60,
                 truncate_log_messages_to: int = 200,
                 fix_malformed_json: bool = True,
                 cacher: ICacher | None = None, disable_cache: bool = False, log_level: int = logging.INFO):
        self._logger = get_logger(__name__, level=log_level)
        if disable_cache:
            self._cacher = DummyCacher()
        else:
            if cacher is None:
                self._cacher = Cacher(log_level=log_level)
            else:
                self._cacher = cacher.sub_cacher(log_level=log_level)
        logging.getLogger("httpx").setLevel(logging.WARNING)
        self._model_name = default_model_name
        self._metadata = {"name": self._model_name}
        self._embedding_model_name = default_embedding_model_name
        self._default_model_config = default_model_config
        if self._default_model_config is None:
            self._default_model_config = ModelConfig()
        self._session_messages = []
        self._system_prompt = system_prompt
        self._keep_chat_session = keep_chat_session
        self._usages = []
        self._truncate_log_messages_to = truncate_log_messages_to
        self._retryer = Retrying(stop=stop_after_attempt(retry_attempts),
                                 wait=wait_exponential(multiplier=1, min=retry_min_wait, max=retry_max_wait),
                                 before_sleep=before_sleep_log(self._logger, log_level=log_level))
        self._fix_malformed_json = fix_malformed_json
        self._prompter = None
        if self._fix_malformed_json:
            self._fix_json_prompt_template = """Below is a malformed JSON object. Your task is to fix it to be a valid JSON, preserving all the data it already contains.
You must return only the fixed JSON object, no additional comments or explanations, just a fixed valid JSON!

Malformed JSON object:
{json_object}"""
        AtExitCollector.register(self)

    def name(self) -> str:
        model_config = self.model_config(default_model_name=self._model_name)
        return model_config.model_name

    @property
    def system_prompt(self) -> str:
        return self._system_prompt

    @property
    def metadata(self) -> dict:
        return self._metadata

    @property
    def usage_summary(self) -> dict:
        input_tokens = sum([usage.input_tokens for usage in self._usages if not usage.cached])
        output_tokens = sum([usage.output_tokens for usage in self._usages if not usage.cached])
        cached_input_tokens = sum([usage.input_tokens for usage in self._usages if usage.cached])
        cached_output_tokens = sum([usage.output_tokens for usage in self._usages if usage.cached])
        total_usage = {
            "request_count": len([usage for usage in self._usages if not usage.cached]),
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": 0,
            "duration": sum([round(usage.duration) for usage in self._usages if not usage.cached]),
            "duration_avg": 0.0,
            "cost": Pricing.estimate(self.name(), input_tokens=input_tokens, output_tokens=output_tokens),
            "cached_request_count": len([usage for usage in self._usages if usage.cached]),
            "cached_input_tokens": cached_input_tokens,
            "cached_output_tokens": cached_output_tokens,
            "cached_total_tokens": 0,
            "cached_duration": sum([round(usage.duration) for usage in self._usages if usage.cached]),
            "cached_duration_avg": 0.0,
            "cached_cost": Pricing.estimate(self.name(), input_tokens=cached_input_tokens, output_tokens=cached_output_tokens)
        }

        total_usage["total_tokens"] = total_usage["input_tokens"] + total_usage["output_tokens"]
        total_usage["cached_total_tokens"] = total_usage["cached_input_tokens"] + total_usage["cached_output_tokens"]
        total_usage["duration_avg"] = float(total_usage["duration"] / total_usage["request_count"]) if total_usage["request_count"] else 0.0
        total_usage["cached_duration_avg"] = float(total_usage["cached_duration"] / total_usage["cached_request_count"]) if total_usage["cached_request_count"] else 0.0

        return total_usage

    def collect(self) -> dict:
        return self.usage_summary

    def label(self) -> str:
        return f"{self.__class__.__name__}({self.name()})"

    def usage(self, tail: int | None = None) -> list[Usage]:
        tail = tail or len(self._usages)
        usages = self._usages[-tail:]
        return usages

    @property
    def embedding_name(self) -> str:
        return self._embedding_model_name

    def model_config(self, base_config: ModelConfig | None = None, default_model_name: str | None = None) -> ModelConfig:
        if base_config is None:
            base_config = self._default_model_config.model_copy()
        if base_config.model_name is None:
            if default_model_name is None:
                raise RuntimeError(f"Model name or default model must be set")
            base_config.model_name = self._model_name
        return base_config

    @abstractmethod
    def _completion(self, prompt: str, system_prompt: str | None, model_config: ModelConfig | None = None,
                    images_base64: list[str] | None = None) -> tuple[str, Usage]:
        pass

    def invalidate_completion_cache(self, prompt: str, model_config: ModelConfig | None = None, images_base64: list[str] | None = None):
        model_config = self.model_config(model_config, self._model_name)
        cache_key, cache_key_usage = self._cache_keys_completion(system_prompt=self._system_prompt, prompt=prompt,
                                                                 model_config=model_config, images_base64=images_base64,
                                                                 is_json=False)
        cache_key_json, cache_key_usage_json = self._cache_keys_completion(system_prompt=self._system_prompt, prompt=prompt,
                                                                           model_config=model_config, images_base64=images_base64,
                                                                           is_json=True)

        self._cacher.unset(cache_key)
        self._cacher.unset(cache_key_usage)
        self._cacher.unset(cache_key_json)
        self._cacher.unset(cache_key_usage_json)

    def completion_json(self, prompt: str, model_config: ModelConfig | None = None,
                        images_base64: list[str] | None = None, validation_schema: dict | None = None,
                        no_cache: bool = False, cache_only: bool = False) -> dict | list[dict]:
        model_cfg = self.model_config(model_config, self._model_name)
        self._logger.debug(f"CompletionJSON started (model: '{model_cfg.model_name}', max_len: {model_cfg.max_new_tokens}, temp: {model_cfg.temperature}, top_p: {model_cfg.top_p})")
        self._logger.debug(f"> Model config (mod): model: {model_cfg.model_name}, max_new_tokens: {model_cfg.max_new_tokens}, temp: {model_cfg.temperature}")

        cache_key, cache_key_usage = self._cache_keys_completion(system_prompt=self._system_prompt, prompt=prompt,
                                                                 model_config=model_cfg, images_base64=images_base64,
                                                                 is_json=True)
        if not no_cache and self._cacher.exists(cache_key) and self._cacher.exists(cache_key_usage):
            self._logger.debug(f"Cache for completion_json already exists ('{cache_key}')")
            usage = self._cacher.get(cache_key_usage)
            usage.cached = True
            self._usages.append(usage)
            return self._cacher.get(cache_key)

        if cache_only:
            raise LLMCacheDoesNotExist()

        self._logger.debug(f"Cache for completion_json does not exist, generating new response")

        response_json, usage = self._retryer(self._completion_json, prompt=prompt, system_prompt=self._system_prompt,
                                             model_config=model_cfg,
                                             images_base64=images_base64,
                                             validation_schema=validation_schema)
        self._usages.append(usage)

        if not no_cache:
            self._cacher.set(cache_key, response_json)
            self._cacher.set(cache_key_usage, usage)

        return response_json

    def _completion_json(self, prompt: str, system_prompt: str, model_config: ModelConfig, images_base64: list[str] | None,
                         validation_schema: dict | None = None) -> tuple[dict | list[dict], Usage]:
        response_str, usage = self._completion(prompt=prompt, system_prompt=system_prompt, model_config=model_config,
                                               images_base64=images_base64)

        response_json = None

        response_str = response_str.replace("```json", "").replace("```", "").strip()

        try:
            response_json = self._decode_json(response_str)
            if validation_schema:
                try:
                    validate(instance=response_json, schema=validation_schema)
                except ValidationError as e:
                    self._logger.error(f"Invalid schema: {e}")
                    raise e

        except (JSONDecodeError, JSONDecoderDecodeError) as e:
            if self._fix_malformed_json and self._fix_json_prompt_template:
                self._logger.warning("Malformed JSON, trying to fix it...")
                self._logger.warning(f"Malformed JSON:\n'{response_str}'")
                fix_json_prompt = self._fix_json_prompt_template.format(json_object=response_str)
                response_json = self.completion_json(fix_json_prompt)
                self._logger.debug(f"Fixed JSON:\n{response_json}")
            if response_json is None:
                self._logger.error(f"Invalid JSON:\n{truncate(response_str, self._truncate_log_messages_to)}\n")
                self._logger.error(f"> prompt:\n{truncate(prompt, self._truncate_log_messages_to)}")
                raise e

        return response_json, usage

    @staticmethod
    def _decode_json(json_str: str) -> dict | list[dict]:
        try:
            content = json.loads(json_str)
            return content
        except (JSONDecodeError, JSONDecoderDecodeError):
            pass

        lines = [line.strip() for line in json_str.strip().split('\n') if line.strip()]
        parsed_data = []

        starting_line = 0
        if len(lines[0]) == 1 and lines[0][0] == "l":
            starting_line = 1

        for i, line in enumerate(lines):
            if starting_line == 1 and i == 0:
                continue
            parsed_data.append(json.loads(line))

        if len(parsed_data) == 1 and starting_line == 0:
            parsed_data = parsed_data[0]

        return parsed_data

    def completion(self, prompt: str, model_config: ModelConfig | None = None,
                   images_base64: list[str] | None = None, no_cache: bool = False, cache_only: bool = False) -> str:
        model_config = self.model_config(model_config, self._model_name)
        self._logger.debug(f"Completion started (model: {model_config.model_name})")

        cache_key, cache_key_usage = self._cache_keys_completion(system_prompt=self._system_prompt, prompt=prompt,
                                                                 model_config=model_config, images_base64=images_base64,
                                                                 is_json=False)
        if not no_cache and self._cacher.exists(cache_key) and self._cacher.exists(cache_key_usage):
            self._logger.debug(f"Cache for the prompt already exists ('{cache_key}')")
            usage_cached = self._cacher.get(cache_key_usage)
            usage_cached.cached = True
            self._usages.append(usage_cached)
            return self._cacher.get(cache_key)

        if cache_only:
            raise LLMCacheDoesNotExist()

        response, usage = self._retryer(self._completion, prompt=prompt, system_prompt=self._system_prompt,
                                        model_config=model_config,
                                        images_base64=images_base64)
        self._usages.append(usage)

        if not no_cache:
            self._cacher.set(cache_key, response)
            self._cacher.set(cache_key_usage, usage)

        self._logger.debug(f"Completion done.")

        return response

    def _cache_keys_completion(self, system_prompt: str, prompt: str, model_config: ModelConfig,
                               images_base64: list[str] | None = None, is_json: bool = False) -> tuple[str, str]:
        cache_key_part_images = images_base64 or []
        json_suffix = "_json" if is_json else ""
        cache_key = self._cacher.create_cache_key([f"llm_completion{json_suffix}"],
                                                  [system_prompt, prompt, str(model_config),
                                                   "_".join(cache_key_part_images)])
        cache_key_usage = self._cacher.create_cache_key([f"llm_completion{json_suffix}__usage"],
                                                        [system_prompt, prompt, str(model_config),
                                                         "_".join(cache_key_part_images)])

        return cache_key, cache_key_usage


def prepare_chat_messages(prompt: str, system_prompt: str | None = None, images_base64: list[str] | None = None,
                          messages_history: list[ChatMessage] | None = None, envelope: bool = False) -> list[ChatMessage] | ChatMessages:
    messages_all = messages_history or []
    if system_prompt:
        if messages_all and len(messages_all) > 0:
            if not messages_all[0]["role"] == "system":
                messages_all = [{"role": "system", "content": system_prompt}] + messages_all
        else:
            messages_all = [{"role": "system", "content": system_prompt}]

    user_message: ChatMessage = {"role": "user", "content": prompt}

    if images_base64:
        user_message = {"role": "user", "content": []}
        user_message["content"].append({"type": "input_text", "text": prompt})
        for b64_image in images_base64:
            user_message["content"].append({"type": "input_message", "image_url": f"data:image/png;base64,{b64_image}"})

    messages_all.append(user_message)

    if envelope:
        return ChatMessages(messages=messages_all)

    return messages_all


def testing():
    from toolchemy.utils.locations import Locations
    locations = Locations()
    data_path = locations.in_resources("tests/ai/malformed.json")
    data_str = locations.read_content(data_path)

    data = LLMClientBase._decode_json(data_str)


if __name__ == "__main__":
    testing()
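For orientation, a minimal usage sketch of prepare_chat_messages as defined above; the prompt text and the base64 payload are placeholders, not values from the package:

from toolchemy.ai.clients.common import prepare_chat_messages

# Builds a system message plus a multimodal user message (input_text + data-URL image).
messages = prepare_chat_messages(prompt="Describe this image.",
                                 system_prompt="You are a terse assistant.",
                                 images_base64=["<base64-encoded-png>"])
# messages[0] -> {"role": "system", "content": "You are a terse assistant."}
# messages[1]["content"][0] -> {"type": "input_text", "text": "Describe this image."}
# messages[1]["content"][1]["image_url"] starts with "data:image/png;base64,"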
toolchemy/ai/clients/dummy_model_client.py
ADDED
@@ -0,0 +1,61 @@
import json
from toolchemy.ai.clients.common import ILLMClient, ModelConfig, Usage


class DummyModelClient(ILLMClient):
    def __init__(self, name: str = "dummy", fixed_response: str | None = None):
        self._name = name
        self._fixed_response = fixed_response
        self._metadata = {"name": self._name}
        self._usages = []

    def name(self) -> str:
        return self._name

    @property
    def metadata(self) -> dict:
        return self._metadata

    def usage(self, tail: int | None = None) -> list[Usage]:
        tail = tail or len(self._usages)
        usages = self._usages[-tail:]
        return usages

    @property
    def system_prompt(self) -> str:
        return "You are a dummy AI Assistant"

    def embeddings(self, text: str, model_name: str = "nomic-embed-text") -> list[float]:
        return 32 * [0.98]

    def _completion(self, prompt: str, system_prompt: str | None, model_config: ModelConfig | None = None,
                    images_base64: list[str] | None = None) -> tuple[str, Usage]:
        if self._fixed_response:
            model_response = self._fixed_response
        else:
            model_response = f"Echo: {prompt}"
        return model_response, Usage(input_tokens=0, output_tokens=0, duration=0.0)

    def completion(self, prompt: str, model_config: ModelConfig | None = None,
                   images_base64: list[str] | None = None, no_cache: bool = False, cache_only: bool = False) -> str:
        response, usage = self._completion(prompt=prompt, system_prompt=self.system_prompt)
        self._usages.append(usage)
        return response

    def completion_json(self, prompt: str, model_config: ModelConfig | None = None,
                        images_base64: list[str] | None = None, validation_schema: dict | None = None,
                        no_cache: bool = False, cache_only: bool = False) -> dict | list[dict]:
        result_str = self.completion(prompt=prompt, model_config=model_config, images_base64=images_base64, cache_only=cache_only)
        return json.loads(result_str)

    def model_config(self, base_config: ModelConfig | None = None,
                     default_model_name: str | None = None) -> ModelConfig:
        return base_config

    @property
    def usage_summary(self) -> dict:
        return {
            "input_tokens": 0,
            "output_tokens": 0,
            "total_tokens": 0,
        }
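A hedged usage sketch for tests: ILLMClient also declares invalidate_completion_cache, which DummyModelClient does not override, so this sketch adds a no-op override (an illustrative assumption, not part of the published file) so the class can be instantiated:

from toolchemy.ai.clients import DummyModelClient

class StubLLM(DummyModelClient):
    # The dummy client keeps no cache, so invalidation is a no-op here.
    def invalidate_completion_cache(self, prompt, model_config=None, images_base64=None):
        pass

stub = StubLLM(fixed_response='{"status": "ok"}')
assert stub.completion("anything") == '{"status": "ok"}'
assert stub.completion_json("anything") == {"status": "ok"}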
toolchemy/ai/clients/factory.py
ADDED
@@ -0,0 +1,37 @@
import logging

from toolchemy.ai.clients.common import LLMClientBase, ModelConfig
from toolchemy.ai.clients import OpenAIClient, OllamaClient, GeminiClient, DummyModelClient
from toolchemy.utils.logger import get_logger

URI_OPENAI = "openai"
URI_GEMINI = "gemini"
URI_DUMMY = "dummy"


def create_llm(name: str, uri: str | None = None, api_key: str | None = None, default_model_config: ModelConfig | None = None, system_prompt: str | None = None, log_level: int = logging.INFO, no_cache: bool = False) -> LLMClientBase:
    logger = get_logger(level=log_level)
    logger.debug(f"Creating llm instance")
    logger.debug(f"> name: {name}")
    if name.startswith("gpt") and not name.startswith("gpt-oss"):
        uri = URI_OPENAI
    elif name.startswith("gemini"):
        uri = URI_GEMINI
    elif uri is None:
        raise ValueError(f"Cannot assume the LLM provider based on the model name: '{name}'. You can pass the uri explicitly as a parameter for this function.")
    logger.debug(f"> uri: {uri}")
    logger.debug(f"> uri assumed: {uri}")

    if uri == URI_OPENAI:
        if not api_key:
            raise ValueError(f"You must pass the 'api_key' explicitly as a parameter for this function.")
        return OpenAIClient(model_name=name, api_key=api_key, system_prompt=system_prompt, default_model_config=default_model_config, no_cache=no_cache)

    if uri == URI_GEMINI:
        if not api_key:
            raise ValueError(f"You must pass the 'api_key' explicitly as a parameter for this function.")
        return GeminiClient(default_model_name=name, api_key=api_key, system_prompt=system_prompt, default_model_config=default_model_config,
                            disable_cache=no_cache, log_level=log_level)

    return OllamaClient(uri=uri, model_name=name, system_prompt=system_prompt, default_model_config=default_model_config, truncate_log_messages_to=2000,
                        disable_cache=no_cache, log_level=log_level)
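A minimal sketch of the routing above; the model names, environment variable, and local URI are placeholders, and a real API key or a running Ollama server would be needed for the calls to succeed:

import os
from toolchemy.ai.clients import create_llm

openai_llm = create_llm("gpt-4o-mini", api_key=os.environ["OPENAI_API_KEY"])  # name starts with "gpt" -> OpenAIClient
local_llm = create_llm("llama3", uri="http://localhost:11434")                # no match -> falls through to OllamaClient
print(local_llm.completion("Reply with a single word."))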
toolchemy/ai/clients/gemini_client.py
ADDED
@@ -0,0 +1,48 @@
import time
import logging
from google import genai
from google.genai import types

from toolchemy.utils.cacher import ICacher
from toolchemy.ai.clients.common import LLMClientBase, ModelConfig, Usage


class GeminiClient(LLMClientBase):
    def __init__(self, api_key: str, default_model_name: str | None = None, default_embedding_model_name: str | None = None,
                 default_model_config: ModelConfig | None = None,
                 system_prompt: str | None = None, keep_chat_session: bool = False,
                 retry_attempts: int = 5, retry_min_wait: int = 2, retry_max_wait: int = 60,
                 truncate_log_messages_to: int = 200,
                 fix_malformed_json: bool = True,
                 cacher: ICacher | None = None, disable_cache: bool = False, log_level: int = logging.INFO):
        super().__init__(default_model_name=default_model_name, default_embedding_model_name=default_embedding_model_name,
                         default_model_config=default_model_config,
                         system_prompt=system_prompt, keep_chat_session=keep_chat_session,
                         retry_attempts=retry_attempts, retry_min_wait=retry_min_wait, retry_max_wait=retry_max_wait,
                         truncate_log_messages_to=truncate_log_messages_to,
                         fix_malformed_json=fix_malformed_json,
                         cacher=cacher, disable_cache=disable_cache, log_level=log_level)
        self._client = genai.Client(api_key=api_key)
        self._logger.debug(f"Gemini client has been initialized")

    def embeddings(self, text: str) -> list[float]:
        raise NotImplementedError()

    def _completion(self, prompt: str, system_prompt: str | None, model_config: ModelConfig | None = None,
                    images_base64: list[str] | None = None) -> tuple[str, Usage]:
        system_prompt = system_prompt or self._system_prompt

        duration_time_start = time.time()

        response = self._client.models.generate_content(
            model=model_config.model_name,
            config=types.GenerateContentConfig(system_instruction=system_prompt),
            contents=prompt
        )

        duration = time.time() - duration_time_start

        usage = Usage(input_tokens=response.usage_metadata.prompt_token_count,
                      output_tokens=response.usage_metadata.total_token_count - response.usage_metadata.prompt_token_count, duration=duration)

        return response.text, usage
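A usage sketch under the assumption of a valid Gemini API key in the environment; the model name is a placeholder:

import os
from toolchemy.ai.clients import GeminiClient

gemini = GeminiClient(api_key=os.environ["GEMINI_API_KEY"], default_model_name="gemini-2.0-flash")
# completion(), retries, and response caching come from LLMClientBase.
print(gemini.completion("Summarize the Zen of Python in one sentence."))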
toolchemy/ai/clients/ollama_client.py
ADDED
@@ -0,0 +1,58 @@
import logging
from ollama import Client

from toolchemy.ai.clients.common import LLMClientBase, ModelConfig, Usage
from toolchemy.utils.cacher import ICacher
from toolchemy.utils.datestimes import Seconds


class OllamaClient(LLMClientBase):
    def __init__(self, uri: str, model_name: str | None = None, embedding_model_name: str | None = "nomic-embed-text",
                 default_model_config: ModelConfig | None = None, system_prompt: str | None = None,
                 keep_chat_session: bool = False,
                 retry_attempts: int = 5, retry_min_wait: int = 2, retry_max_wait: int = 60,
                 truncate_log_messages_to: int = 200, fix_malformed_json: bool = True,
                 cacher: ICacher | None = None, disable_cache: bool = False, log_level: int = logging.INFO):
        super().__init__(default_model_name=model_name, default_embedding_model_name=embedding_model_name,
                         default_model_config=default_model_config,
                         system_prompt=system_prompt, keep_chat_session=keep_chat_session,
                         retry_attempts=retry_attempts, retry_min_wait=retry_min_wait, retry_max_wait=retry_max_wait,
                         truncate_log_messages_to=truncate_log_messages_to, fix_malformed_json=fix_malformed_json,
                         cacher=cacher, disable_cache=disable_cache, log_level=log_level)
        self._uri = uri
        self._metadata["uri"] = self._uri
        assert self._uri, f"The model uri cannot be empty!"

        self._client = Client(host=self._uri)
        self._logger.debug(f"Ollama client has been initialized ({self._uri})")

    def embeddings(self, text: str) -> list[float]:
        cache_key = self._cacher.create_cache_key(["llm_embeddings"], [text])
        if self._cacher.exists(cache_key):
            self._logger.debug(f"Cache for the text embeddings already exists")
            return self._cacher.get(cache_key)

        results_raw = self._client.embed(model=self.embedding_name, input=text)
        results = [v for v in results_raw.embeddings[0]]

        self._cacher.set(cache_key, results)

        return results

    def _completion(self, prompt: str, system_prompt: str | None = None, model_config: ModelConfig | None = None,
                    images_base64: list[str] | None = None) -> tuple[str, Usage]:

        system_prompt = system_prompt or self.system_prompt
        result = self._client.generate(model=model_config.model_name, system=system_prompt, prompt=prompt,
                                       options={
                                           "temperature": model_config.temperature,
                                           "top_p": model_config.top_p,
                                           "num_predict": model_config.max_new_tokens,
                                       }, images=images_base64)

        total_duration_s = result.total_duration * Seconds.NANOSECOND
        usage = Usage(input_tokens=result.prompt_eval_count, output_tokens=result.eval_count, duration=total_duration_s)

        self._logger.debug(f"Completion done.")

        return result.response, usage
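A usage sketch assuming an Ollama server reachable at the given URI and that the named models have already been pulled; the model names and the URI are placeholders:

from toolchemy.ai.clients import OllamaClient, ModelConfig

llm = OllamaClient(uri="http://localhost:11434", model_name="llama3",
                   default_model_config=ModelConfig(temperature=0.1, max_new_tokens=256))
vector = llm.embeddings("retrieval test sentence")                  # cached by the Cacher on repeat calls
reply = llm.completion("List three prime numbers.", no_cache=True)  # bypasses the completion cache
print(len(vector), reply)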