toolchemy-0.2.185-py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as published.
- toolchemy/__main__.py +9 -0
- toolchemy/ai/clients/__init__.py +20 -0
- toolchemy/ai/clients/common.py +429 -0
- toolchemy/ai/clients/dummy_model_client.py +61 -0
- toolchemy/ai/clients/factory.py +37 -0
- toolchemy/ai/clients/gemini_client.py +48 -0
- toolchemy/ai/clients/ollama_client.py +58 -0
- toolchemy/ai/clients/openai_client.py +76 -0
- toolchemy/ai/clients/pricing.py +66 -0
- toolchemy/ai/clients/whisper_client.py +141 -0
- toolchemy/ai/prompter.py +124 -0
- toolchemy/ai/trackers/__init__.py +5 -0
- toolchemy/ai/trackers/common.py +216 -0
- toolchemy/ai/trackers/mlflow_tracker.py +221 -0
- toolchemy/ai/trackers/neptune_tracker.py +135 -0
- toolchemy/db/lightdb.py +260 -0
- toolchemy/utils/__init__.py +19 -0
- toolchemy/utils/at_exit_collector.py +109 -0
- toolchemy/utils/cacher/__init__.py +20 -0
- toolchemy/utils/cacher/cacher_diskcache.py +121 -0
- toolchemy/utils/cacher/cacher_pickle.py +152 -0
- toolchemy/utils/cacher/cacher_shelve.py +196 -0
- toolchemy/utils/cacher/common.py +174 -0
- toolchemy/utils/datestimes.py +77 -0
- toolchemy/utils/locations.py +111 -0
- toolchemy/utils/logger.py +76 -0
- toolchemy/utils/timer.py +23 -0
- toolchemy/utils/utils.py +168 -0
- toolchemy/vision/__init__.py +5 -0
- toolchemy/vision/caption_overlay.py +77 -0
- toolchemy/vision/image.py +89 -0
- toolchemy-0.2.185.dist-info/METADATA +25 -0
- toolchemy-0.2.185.dist-info/RECORD +36 -0
- toolchemy-0.2.185.dist-info/WHEEL +4 -0
- toolchemy-0.2.185.dist-info/entry_points.txt +3 -0
- toolchemy-0.2.185.dist-info/licenses/LICENSE +21 -0

toolchemy/ai/clients/openai_client.py
ADDED
@@ -0,0 +1,76 @@
import time
from abc import ABC, abstractmethod
from openai import OpenAI, AzureOpenAI, NOT_GIVEN
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from typing import Iterable

from toolchemy.ai.clients.common import LLMClientBase, ModelConfig, Usage, prepare_chat_messages


class BaseOpenAIClient(LLMClientBase, ABC):
    @property
    @abstractmethod
    def _client(self) -> OpenAI:
        pass

    def embeddings(self, text: str) -> list[float]:
        response = self._client.embeddings.create(
            input=text,
            model=self.embedding_name
        )
        return response.data[0].embedding

    def _completion(self, prompt: str, system_prompt: str | None, model_config: ModelConfig | None = None,
                    images_base64: list[str] | None = None) -> tuple[str, Usage]:
        messages = self._prepare_chat_messages(prompt=prompt, system_prompt=system_prompt, images_base64=images_base64)

        duration_time_start = time.time()
        response = self._client.chat.completions.create(
            model=model_config.model_name,
            messages=messages,
            top_p=model_config.top_p,
        )
        duration = time.time() - duration_time_start

        usage = Usage(input_tokens=response.usage.prompt_tokens, output_tokens=response.usage.completion_tokens,
                      duration=duration)

        return response.choices[0].message.content, usage

    def _prepare_chat_messages(self, prompt: str, system_prompt: str | None = None, images_base64: list[str] | None = None) -> list[ChatCompletionMessageParam]:
        messages_all = []
        system_prompt = system_prompt or self._system_prompt or NOT_GIVEN

        if self._keep_chat_session:
            messages_all.extend(self._session_messages)

        return prepare_chat_messages(prompt=prompt, system_prompt=system_prompt, images_base64=images_base64, messages_history=messages_all)


class OpenAIClient(BaseOpenAIClient):
    def __init__(self, api_key: str, model_name: str = "gpt-3.5-turbo",
                 embedding_model_name: str = "text-embedding-3-large", default_model_config: ModelConfig | None = None,
                 system_prompt: str | None = None, keep_chat_session: bool = False, no_cache: bool = False):
        super().__init__(default_model_name=model_name, default_embedding_model_name=embedding_model_name,
                         default_model_config=default_model_config,
                         system_prompt=system_prompt, keep_chat_session=keep_chat_session, disable_cache=no_cache)
        self._openai_client = OpenAI(api_key=api_key)

    @property
    def _client(self) -> OpenAI:
        return self._openai_client


class AzureOpenAIClient(BaseOpenAIClient):
    def __init__(self, api_key: str, api_endpoint: str, api_version: str, model_name: str = "gpt-3.5-turbo",
                 embedding_model_name: str = "text-embedding-3-large", default_model_config: ModelConfig | None = None,
                 system_prompt: str | None = None, keep_chat_session: bool = False):
        super().__init__(default_model_name=model_name, default_embedding_model_name=embedding_model_name,
                         default_model_config=default_model_config,
                         system_prompt=system_prompt, keep_chat_session=keep_chat_session)
        self._openai_client = AzureOpenAI(api_key=api_key, azure_endpoint=api_endpoint, api_version=api_version,
                                          azure_deployment=model_name)

    @property
    def _client(self) -> OpenAI:
        return self._openai_client
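
A minimal usage sketch for the clients above. The public completion entry point presumably lives on LLMClientBase in common.py (not shown here), so this calls the internal `_completion` hook directly; the `ModelConfig` constructor arguments are assumptions inferred from the attribute accesses above, which also show that `_completion` dereferences `model_config` unconditionally, so a config must be supplied when calling it this way.

# Hypothetical usage sketch; ModelConfig fields are inferred, not confirmed.
from toolchemy.ai.clients.openai_client import OpenAIClient
from toolchemy.ai.clients.common import ModelConfig

client = OpenAIClient(api_key="...", model_name="gpt-4.1-mini")

# model_config is required here because _completion reads
# model_config.model_name and model_config.top_p without a None check.
answer, usage = client._completion(
    prompt="Name the capital of France.",
    system_prompt="Answer in one word.",
    model_config=ModelConfig(model_name="gpt-4.1-mini", top_p=1.0),  # assumed fields
)
print(answer, usage.input_tokens, usage.output_tokens)
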
toolchemy/ai/clients/pricing.py
ADDED
@@ -0,0 +1,66 @@
KEY_INPUT_TOKENS = "input_tokens_cost"
KEY_OUTPUT_TOKENS = "output_tokens_cost"



class Pricing:
    pricing_per_1_mln = {
        "gpt-5.2": {
            KEY_INPUT_TOKENS: 1.75,
            KEY_OUTPUT_TOKENS: 14.00,
        },
        "gpt-5.2-pro": {
            KEY_INPUT_TOKENS: 21.00,
            KEY_OUTPUT_TOKENS: 168.00,
        },
        "gpt-5-mini": {
            KEY_INPUT_TOKENS: 0.25,
            KEY_OUTPUT_TOKENS: 2.00,
        },
        "gpt-4.1": {
            KEY_INPUT_TOKENS: 3.00,
            KEY_OUTPUT_TOKENS: 12.00,
        },
        "gpt-4.1-mini": {
            KEY_INPUT_TOKENS: 0.8,
            KEY_OUTPUT_TOKENS: 3.2,
        },
        "gpt-4.1-nano": {
            KEY_INPUT_TOKENS: 0.2,
            KEY_OUTPUT_TOKENS: 0.8,
        },
        "o4-mini": {
            KEY_INPUT_TOKENS: 4.00,
            KEY_OUTPUT_TOKENS: 16.00,
        },
        "mistral-small3.2:24b": {
            KEY_INPUT_TOKENS: 0.01,
            KEY_OUTPUT_TOKENS: 0.03,
        },
        "gpt-oss:120b": {
            KEY_INPUT_TOKENS: 0.01,
            KEY_OUTPUT_TOKENS: 0.03,
        },
        "gemma3:27b": {
            KEY_INPUT_TOKENS: 0.01,
            KEY_OUTPUT_TOKENS: 0.03,
        },
        "qwen3:32b-q8_0": {
            KEY_INPUT_TOKENS: 0.01,
            KEY_OUTPUT_TOKENS: 0.03,
        },
        "dummy-model": {
            KEY_INPUT_TOKENS: 0.1,
            KEY_OUTPUT_TOKENS: 1.0,
        }
    }

    @classmethod
    def estimate(cls, model_name: str, input_tokens: int, output_tokens: int) -> float:
        if model_name not in cls.pricing_per_1_mln:
            raise ValueError(f"Model '{model_name}' not supported for pricing estimation")

        model_pricing = cls.pricing_per_1_mln[model_name]
        input_cost = model_pricing[KEY_INPUT_TOKENS] * (input_tokens / 1_000_000)
        output_cost = model_pricing[KEY_OUTPUT_TOKENS] * (output_tokens / 1_000_000)
        return input_cost + output_cost
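
To make the rate table concrete, here is the estimate arithmetic worked through for gpt-4.1-mini ($0.80 per million input tokens, $3.20 per million output tokens): 10,000 input tokens cost 0.8 × 0.01 = $0.008 and 2,000 output tokens cost 3.2 × 0.002 = $0.0064, for a total of $0.0144.

from toolchemy.ai.clients.pricing import Pricing

cost = Pricing.estimate("gpt-4.1-mini", input_tokens=10_000, output_tokens=2_000)
print(f"${cost:.4f}")  # $0.0144

# Unknown models raise instead of silently returning zero:
try:
    Pricing.estimate("unknown-model", 1, 1)
except ValueError as exc:
    print(exc)  # Model 'unknown-model' not supported for pricing estimation
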
toolchemy/ai/clients/whisper_client.py
ADDED
@@ -0,0 +1,141 @@
import asyncio
import requests
import sys
import os
import subprocess
import tempfile
from wyoming.client import AsyncClient
from wyoming.audio import AudioChunk, AudioStart, AudioStop, AudioChunkConverter
from wyoming.asr import Transcript
from wyoming.ping import Ping
from wyoming.info import Describe
import wave

from toolchemy.utils.logger import get_logger


class WhisperClient:
    def __init__(self, url: str):
        self._logger = get_logger()
        self._endpoint = url
        self._whisper_client_wyoming = None
        if self._endpoint.startswith("http"):
            if not self._endpoint.endswith("transcribe"):
                if not self._endpoint.endswith("/"):
                    self._endpoint += "/"
                self._endpoint += "transcribe"
        elif self._endpoint.startswith("tcp"):
            self._whisper_client_wyoming = AsyncClient.from_uri(self._endpoint)
        else:
            raise ValueError(f"Unknown protocol for the whisper server endpoint: '{self._endpoint}'")
        self._logger.info(f"Whisper client initialized (endpoint: '{self._endpoint}')")

    def transcribe(self, audio_path: str) -> str:
        transcription = None

        if self._endpoint.startswith("tcp"):
            transcription = asyncio.run(self._transcribe_wyoming(audio_path))

        if self._endpoint.startswith("http"):
            transcription = self._transcribe_http(audio_path)

        if transcription is None:
            raise RuntimeError("Transcription failed...")

        return transcription.strip()

    def _transcribe_http(self, audio_path: str) -> str:
        if not os.path.exists(audio_path):
            raise ValueError(f"Error: File '{audio_path}' not found.")

        with open(audio_path, "rb") as audio_file:
            files = {"file": audio_file}

            self._logger.info(f"Sending '{audio_path}' to Whisper server...")
            response = requests.post(self._endpoint, files=files)

        if response.status_code == 200:
            result = response.json()
            result_transcription = result.get("text")
            self._logger.info(f"Transcription: '{result_transcription}'")
            return result_transcription

        err_msg = f"Error: Failed to transcribe. Status Code: {response.status_code}"
        self._logger.error(err_msg)
        raise RuntimeError(err_msg)

    async def _transcribe_wyoming(self, audio_path: str, audio_rate: int = 16000, audio_width: int = 2,
                                  audio_channels: int = 1, chunk_size: int = 1024) -> str:
        wav_path = self._convert_to_wav(audio_path, audio_rate=audio_rate, audio_channels=audio_channels)

        await self._whisper_client_wyoming.connect()
        await self._whisper_client_wyoming.write_event(Ping(text="test").event())

        await self._whisper_client_wyoming.write_event(Describe().event())

        info_event = await self._whisper_client_wyoming.read_event()
        self._logger.info(f"info event: {info_event}")

        with wave.open(wav_path, 'rb') as wav:
            assert wav.getframerate() == audio_rate
            assert wav.getsampwidth() == audio_width
            assert wav.getnchannels() == audio_channels

            await self._whisper_client_wyoming.write_event(AudioStart(audio_rate, audio_width, audio_channels).event())

            audio_bytes = wav.readframes(chunk_size)

            converter = AudioChunkConverter(
                rate=audio_rate, width=audio_width, channels=audio_channels,
            )

            while audio_bytes:
                chunk = converter.convert(AudioChunk(audio_rate, audio_width, audio_channels, audio_bytes))
                await self._whisper_client_wyoming.write_event(chunk.event())
                audio_bytes = wav.readframes(chunk_size)

            await self._whisper_client_wyoming.write_event(AudioStop().event())

            while True:
                event = await asyncio.wait_for(self._whisper_client_wyoming.read_event(), timeout=30)
                if event is None:
                    break
                transcript = Transcript.from_event(event)
                if transcript.text:
                    self._logger.info(f"Transcription: {transcript.text}")
                    break
            await self._whisper_client_wyoming.disconnect()

        return transcript.text

    def _convert_to_wav(self, input_path: str, audio_rate: int, audio_channels: int) -> str:
        temp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        temp_wav.close()

        cmd = [
            "ffmpeg",
            "-y",
            "-i", input_path,
            "-ar", str(audio_rate),
            "-ac", str(audio_channels),
            "-f", "wav",
            "-sample_fmt", "s16",
            temp_wav.name
        ]

        subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        return temp_wav.name


def main(argv: list):
    if len(argv) < 2:
        raise ValueError("Usage: python transcribe_audio.py <audio_file.mp3/wav>")

    file_path = argv[1]
    client = WhisperClient(url="tcp://hal:10300")
    transcription = client.transcribe(file_path)
    print(f"> '{transcription}'")


if __name__ == "__main__":
    main(sys.argv)
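
A usage sketch covering both transports; the hostnames, ports, and file names below are placeholders. In HTTP mode the constructor appends "/transcribe" to the base URL if it is missing; in Wyoming (tcp) mode the audio is first converted to 16 kHz mono 16-bit WAV via ffmpeg, so ffmpeg must be available on PATH.

from toolchemy.ai.clients.whisper_client import WhisperClient

# HTTP mode: posts the file to <url>/transcribe and reads the "text" field.
http_client = WhisperClient(url="http://localhost:8000")
print(http_client.transcribe("recording.mp3"))

# Wyoming mode: converts with ffmpeg, then streams AudioChunk events over tcp.
tcp_client = WhisperClient(url="tcp://localhost:10300")
print(tcp_client.transcribe("recording.mp3"))
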
toolchemy/ai/prompter.py
ADDED
@@ -0,0 +1,124 @@
import logging
import mlflow
import sys
import subprocess
from abc import ABC, abstractmethod

from toolchemy.utils.cacher import Cacher, ICacher
from toolchemy.utils.logger import get_logger
from toolchemy.utils.locations import Locations


class IPrompter(ABC):
    @abstractmethod
    def render(self, name: str, version: str | dict[str, str | list[str]] | None = None, **variables) -> str:
        pass

    @abstractmethod
    def template(self, name: str, version: str | dict[str, str | list[str]] | None = None) -> str:
        pass

    @abstractmethod
    def create_template(self, name: str, template: str, overwrite: bool = False):
        pass

    @abstractmethod
    def run_studio(self):
        pass


def run_studio():
    prompter = PrompterMLflow()
    prompter.run_studio()


class PrompterBase(IPrompter, ABC):
    def __init__(self, cacher: ICacher | None = None, no_cache: bool = False, log_level: int = logging.INFO):
        self._logger = get_logger(level=log_level)
        self._cacher = cacher or Cacher(disabled=no_cache)

    def _prompt_version(self, name: str, version_mapping: str | dict[str, str | list[str]] | None = None) -> str | None:
        if version_mapping is None:
            return self._latest_value()
        if isinstance(version_mapping, int):
            return str(version_mapping)
        if isinstance(version_mapping, str):
            return version_mapping
        if name not in version_mapping:
            return None
        if isinstance(version_mapping[name], str):
            return version_mapping[name]
        return version_mapping[name][0]

    @abstractmethod
    def _latest_value(self) -> str | None:
        pass


class PrompterMLflow(PrompterBase):
    DEFAULT_PROMPT_REGISTRY_NAME = "prompts_mlflow"

    def __init__(self, registry_store_dir: str | None = None, cacher: ICacher | None = None, no_cache: bool = False, log_level: int = logging.INFO):
        super().__init__(cacher=cacher, no_cache=no_cache, log_level=log_level)
        locations = Locations()
        if registry_store_dir is None:
            registry_store_dir = locations.in_root(self.DEFAULT_PROMPT_REGISTRY_NAME)
        registry_store_dir = locations.abs(registry_store_dir).rstrip("/")

        self._registry_store_uri = f"sqlite:///{registry_store_dir}/registry.db"
        self._tracking_uri = f"sqlite:///{registry_store_dir}/tracking.db"
        mlflow.set_tracking_uri(self._tracking_uri)
        mlflow.set_registry_uri(self._registry_store_uri)

        self._logger.info("Prompter-MLflow initialized")
        self._logger.info(f"> tracking store uri: {self._tracking_uri}")
        self._logger.info(f"> registry store uri: {self._registry_store_uri}")

    def render(self, name: str, version: str | dict[str, str | list[str]] | None = None, **variables) -> str:
        prompt_uri = self._build_prompt_uri(name=name, version=version)

        cache_key = self._cacher.create_cache_key(["prompt_render", prompt_uri], [variables])
        if self._cacher.exists(cache_key):
            return self._cacher.get(cache_key)

        self._logger.debug(f"Rendering prompt: '{name}' (version: '{version}') -> prompt uri: '{prompt_uri}'")

        prompt_template = mlflow.genai.load_prompt(prompt_uri)
        prompt = prompt_template.format(**variables)

        self._cacher.set(cache_key, prompt)

        return prompt

    def template(self, name: str, version: str | dict[str, str | list[str]] | None = None) -> str:
        prompt_uri = self._build_prompt_uri(name=name, version=version)

        cache_key = self._cacher.create_cache_key(["prompt_template", prompt_uri])
        if self._cacher.exists(cache_key):
            return self._cacher.get(cache_key)

        self._logger.debug(f"Getting prompt template: '{name}' (version: '{version}') -> prompt uri: '{prompt_uri}'")

        prompt_template = mlflow.genai.load_prompt(prompt_uri)

        self._cacher.set(cache_key, prompt_template.template)

        return prompt_template.template

    def create_template(self, name: str, template: str, overwrite: bool = False):
        if mlflow.genai.load_prompt(name_or_uri=name, allow_missing=True) and not overwrite:
            return
        mlflow.genai.register_prompt(name=name, template=template)

    def _build_prompt_uri(self, name: str, version: str | int | None = None) -> str:
        prompt_version = self._prompt_version(name, version_mapping=version)
        prompt_uri = f"prompts:/{name}"
        prompt_uri += f"@{prompt_version}" if isinstance(prompt_version, str) and not prompt_version.isdigit() else f"/{prompt_version}"
        return prompt_uri

    def run_studio(self):
        command = [sys.executable, "-m", "mlflow", "ui", "--registry-store-uri", self._registry_store_uri, "--backend-store-uri", self._tracking_uri]
        sys.exit(subprocess.call(command))

    def _latest_value(self) -> str | None:
        return "latest"
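
A sketch of the registry round trip, assuming a writable directory for the SQLite-backed stores; the path and prompt name are hypothetical. create_template registers only when the prompt does not already exist (unless overwrite=True), and render resolves version=None to the "@latest" alias, an integer to a "prompts:/<name>/<n>" URI, and a non-numeric string to an "@<alias>" URI. Note that render takes name as its first positional parameter, so a template variable literally named "name" cannot be passed through **variables; the example uses "user" instead.

from toolchemy.ai.prompter import PrompterMLflow

prompter = PrompterMLflow(registry_store_dir="/tmp/prompts_mlflow")  # hypothetical dir

# MLflow prompt templates use {{ variable }} placeholders.
prompter.create_template("greeting", "Hello, {{ user }}!")

print(prompter.render("greeting", user="Ada"))   # resolves prompts:/greeting@latest
print(prompter.template("greeting", version=1))  # resolves prompts:/greeting/1
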
toolchemy/ai/trackers/common.py
ADDED
@@ -0,0 +1,216 @@
from abc import ABC, abstractmethod
import statistics
from typing import Dict, Any

from toolchemy.utils.logger import get_logger


class ITracker(ABC):
    @property
    @abstractmethod
    def experiment_name(self) -> str:
        pass

    @property
    @abstractmethod
    def run_name(self) -> str:
        pass

    @abstractmethod
    def start_run(
        self, run_id: str = None,
        run_name: str = None,
        parent_run_id: str = None,
        user_specified_tags: Dict[str, str] = None
    ):
        pass

    @abstractmethod
    def end_run(self):
        pass

    @abstractmethod
    def log(self, name: str, value: Any):
        pass

    @abstractmethod
    def log_param(self, name: str, value: Any):
        pass

    @abstractmethod
    def log_params(self, params: Dict[str, Any]):
        pass

    @abstractmethod
    def log_metric(self, name: str, value: float, step: int | None = None, metric_metadata: dict | None = None):
        pass

    @abstractmethod
    def log_text(self, name: str, value: str):
        pass

    @abstractmethod
    def log_metrics(self, metrics: Dict[str, float | list], step: int | None = None):
        pass

    @abstractmethod
    def log_artifact(self, artifact_path: str, save_dir: str = None):
        pass

    @abstractmethod
    def log_figure(self, figure, save_path: str):
        """
        Parameters:
            figure (matplotlib.figure.Figure or plotly.graph_objects.Figure): plot figure
            save_path (str): the run-relative artifact file path in posixpath format to which
                the figure is saved (e.g. "dir/file.png").
        """
        pass

    @abstractmethod
    def set_run_tag(self, name: str, value: str | int | float):
        pass

    @abstractmethod
    def set_experiment_tag(self, name: str, value: str | int | float):
        pass

    @abstractmethod
    def disable(self):
        pass

    @abstractmethod
    def get_data(self) -> dict:
        pass


class TrackerBase(ITracker, ABC):
    def __init__(self, experiment_name: str, with_artifact_logging: bool = True, disabled: bool = False):
        self._disabled = disabled
        self._logger = get_logger()
        self._experiment_name = experiment_name
        self._artifact_logging = with_artifact_logging
        self._metrics = {}
        self._params = {}
        self._tags = {
            "experiment": {},
            "runs": {},
        }

    @property
    def experiment_name(self) -> str:
        return self._experiment_name

    def get_max_metric_value(self, name: str) -> float:
        return max(self._metrics[name], key=lambda el: el['value'])

    def get_min_metric_value(self, name: str) -> float:
        return min(self._metrics[name], key=lambda el: el['value'])

    def get_avg_metric_value(self, name: str) -> float:
        metric_values = [m['value'] for m in self._metrics[name]]
        return statistics.mean(metric_values)

    def get_data(self) -> dict:
        return {
            "metrics": self._metrics.copy(),
            "params": self._params.copy(),
            "tags": self._tags.copy(),
        }

    def _store_param(self, name: str, value: Any):
        if self._disabled:
            raise RuntimeError("Disabled trackers cannot store params!")

        self._params[name] = value

    def _store_tag(self, name: str, value: str | int | float, run_name: str | None = None):
        if run_name is None:
            self._tags["experiment"][name] = value
            return
        if run_name not in self._tags["runs"]:
            self._tags["runs"][run_name] = {}
        self._tags["runs"][run_name][name] = value

    def _store_metric(self, name: str, value: float, metric_metadata: dict | None = None) -> float:
        if self._disabled:
            raise RuntimeError("Disabled trackers cannot store metrics!")
        if name not in self._metrics:
            self._metrics[name] = []

        if isinstance(value, dict):
            new_entry = value
        else:
            new_entry = {
                'value': value
            }

        if metric_metadata:
            new_entry.update(metric_metadata)

        self._metrics[name] += [new_entry]

        return new_entry['value']

    def disable(self):
        self._disabled = True


class InMemoryTracker(TrackerBase):
    def __init__(self, experiment_name: str = "dummy", disabled: bool = False):
        super().__init__(experiment_name=experiment_name, disabled=disabled)
        self._run_name = None
        self._reset()

    @property
    def run_name(self) -> str:
        return self._run_name

    def start_run(
        self, run_id: str = None,
        run_name: str = None,
        parent_run_id: str = None,
        user_specified_tags: Dict[str, str] = None
    ):
        self._run_name = run_name

    def _reset(self):
        self._run_name = None
        self._data = {}
        self._params = {}
        self._metrics = {}

    def end_run(self):
        self._reset()

    def log(self, name: str, value: Any):
        self._data[name] = value

    def log_param(self, name: str, value: Any):
        self._params[name] = value

    def log_params(self, params: Dict[str, Any]):
        for name, value in params.items():
            self.log_param(name, value)

    def log_text(self, name: str, value: str):
        self.log(name, value)

    def log_metric(self, name: str, value: float, step: int | None = None, metric_metadata: dict | None = None):
        self._metrics[name] = value

    def log_metrics(self, metrics: Dict[str, float | list], step: int | None = None):
        for name, value in metrics.items():
            self.log_metric(name, value, step)

    def log_artifact(self, artifact_path: str, save_dir: str = None):
        raise NotImplementedError()

    def log_figure(self, figure, save_path: str):
        raise NotImplementedError()

    def set_run_tag(self, name: str, value: str | int | float):
        self._store_tag(name, value, run_name=self.run_name)

    def set_experiment_tag(self, name: str, value: str | int | float):
        self._store_tag(name, value)