toolchemy-0.2.185-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. toolchemy/__main__.py +9 -0
  2. toolchemy/ai/clients/__init__.py +20 -0
  3. toolchemy/ai/clients/common.py +429 -0
  4. toolchemy/ai/clients/dummy_model_client.py +61 -0
  5. toolchemy/ai/clients/factory.py +37 -0
  6. toolchemy/ai/clients/gemini_client.py +48 -0
  7. toolchemy/ai/clients/ollama_client.py +58 -0
  8. toolchemy/ai/clients/openai_client.py +76 -0
  9. toolchemy/ai/clients/pricing.py +66 -0
  10. toolchemy/ai/clients/whisper_client.py +141 -0
  11. toolchemy/ai/prompter.py +124 -0
  12. toolchemy/ai/trackers/__init__.py +5 -0
  13. toolchemy/ai/trackers/common.py +216 -0
  14. toolchemy/ai/trackers/mlflow_tracker.py +221 -0
  15. toolchemy/ai/trackers/neptune_tracker.py +135 -0
  16. toolchemy/db/lightdb.py +260 -0
  17. toolchemy/utils/__init__.py +19 -0
  18. toolchemy/utils/at_exit_collector.py +109 -0
  19. toolchemy/utils/cacher/__init__.py +20 -0
  20. toolchemy/utils/cacher/cacher_diskcache.py +121 -0
  21. toolchemy/utils/cacher/cacher_pickle.py +152 -0
  22. toolchemy/utils/cacher/cacher_shelve.py +196 -0
  23. toolchemy/utils/cacher/common.py +174 -0
  24. toolchemy/utils/datestimes.py +77 -0
  25. toolchemy/utils/locations.py +111 -0
  26. toolchemy/utils/logger.py +76 -0
  27. toolchemy/utils/timer.py +23 -0
  28. toolchemy/utils/utils.py +168 -0
  29. toolchemy/vision/__init__.py +5 -0
  30. toolchemy/vision/caption_overlay.py +77 -0
  31. toolchemy/vision/image.py +89 -0
  32. toolchemy-0.2.185.dist-info/METADATA +25 -0
  33. toolchemy-0.2.185.dist-info/RECORD +36 -0
  34. toolchemy-0.2.185.dist-info/WHEEL +4 -0
  35. toolchemy-0.2.185.dist-info/entry_points.txt +3 -0
  36. toolchemy-0.2.185.dist-info/licenses/LICENSE +21 -0
toolchemy/ai/clients/openai_client.py
@@ -0,0 +1,76 @@
+ import time
+ from abc import ABC, abstractmethod
+ from openai import OpenAI, AzureOpenAI, NOT_GIVEN
+ from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
+ from typing import Iterable
+
+ from toolchemy.ai.clients.common import LLMClientBase, ModelConfig, Usage, prepare_chat_messages
+
+
+ class BaseOpenAIClient(LLMClientBase, ABC):
+     @property
+     @abstractmethod
+     def _client(self) -> OpenAI:
+         pass
+
+     def embeddings(self, text: str) -> list[float]:
+         response = self._client.embeddings.create(
+             input=text,
+             model=self.embedding_name
+         )
+         return response.data[0].embedding
+
+     def _completion(self, prompt: str, system_prompt: str | None, model_config: ModelConfig | None = None,
+                     images_base64: list[str] | None = None) -> tuple[str, Usage]:
+         messages = self._prepare_chat_messages(prompt=prompt, system_prompt=system_prompt, images_base64=images_base64)
+
+         duration_time_start = time.time()
+         response = self._client.chat.completions.create(
+             model=model_config.model_name,
+             messages=messages,
+             top_p=model_config.top_p,
+         )
+         duration = time.time() - duration_time_start
+
+         usage = Usage(input_tokens=response.usage.prompt_tokens, output_tokens=response.usage.completion_tokens,
+                       duration=duration)
+
+         return response.choices[0].message.content, usage
+
+     def _prepare_chat_messages(self, prompt: str, system_prompt: str | None = None, images_base64: list[str] | None = None) -> list[ChatCompletionMessageParam]:
+         messages_all = []
+         system_prompt = system_prompt or self._system_prompt or NOT_GIVEN
+
+         if self._keep_chat_session:
+             messages_all.extend(self._session_messages)
+
+         return prepare_chat_messages(prompt=prompt, system_prompt=system_prompt, images_base64=images_base64, messages_history=messages_all)
+
+
+ class OpenAIClient(BaseOpenAIClient):
+     def __init__(self, api_key: str, model_name: str = "gpt-3.5-turbo",
+                  embedding_model_name: str = "text-embedding-3-large", default_model_config: ModelConfig | None = None,
+                  system_prompt: str | None = None, keep_chat_session: bool = False, no_cache: bool = False):
+         super().__init__(default_model_name=model_name, default_embedding_model_name=embedding_model_name,
+                          default_model_config=default_model_config,
+                          system_prompt=system_prompt, keep_chat_session=keep_chat_session, disable_cache=no_cache)
+         self._openai_client = OpenAI(api_key=api_key)
+
+     @property
+     def _client(self) -> OpenAI:
+         return self._openai_client
+
+
+ class AzureOpenAIClient(BaseOpenAIClient):
+     def __init__(self, api_key: str, api_endpoint: str, api_version: str, model_name: str = "gpt-3.5-turbo",
+                  embedding_model_name: str = "text-embedding-3-large", default_model_config: ModelConfig | None = None,
+                  system_prompt: str | None = None, keep_chat_session: bool = False):
+         super().__init__(default_model_name=model_name, default_embedding_model_name=embedding_model_name,
+                          default_model_config=default_model_config,
+                          system_prompt=system_prompt, keep_chat_session=keep_chat_session)
+         self._openai_client = AzureOpenAI(api_key=api_key, azure_endpoint=api_endpoint, api_version=api_version,
+                                           azure_deployment=model_name)
+
+     @property
+     def _client(self) -> OpenAI:
+         return self._openai_client
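Below is a hypothetical usage sketch, not part of the wheel. The public completion entry point lives on LLMClientBase in common.py, which this section does not show, so the sketch sticks to the constructor and embeddings() visible above; the model name and environment variable are placeholders.

    import os
    from toolchemy.ai.clients.openai_client import OpenAIClient

    # Plain OpenAI backend; unset arguments fall back to the defaults in the
    # signature above (gpt-3.5-turbo, text-embedding-3-large).
    client = OpenAIClient(api_key=os.environ["OPENAI_API_KEY"], model_name="gpt-4.1-mini")

    # embeddings() is defined on BaseOpenAIClient and returns list[float],
    # using the embedding model configured at construction time.
    vector = client.embeddings("hello world")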
toolchemy/ai/clients/pricing.py
@@ -0,0 +1,66 @@
+ KEY_INPUT_TOKENS = "input_tokens_cost"
+ KEY_OUTPUT_TOKENS = "output_tokens_cost"
+
+
+
+ class Pricing:
+     pricing_per_1_mln = {
+         "gpt-5.2": {
+             KEY_INPUT_TOKENS: 1.75,
+             KEY_OUTPUT_TOKENS: 14.00,
+         },
+         "gpt-5.2-pro": {
+             KEY_INPUT_TOKENS: 21.00,
+             KEY_OUTPUT_TOKENS: 168.00,
+         },
+         "gpt-5-mini": {
+             KEY_INPUT_TOKENS: 0.25,
+             KEY_OUTPUT_TOKENS: 2.00,
+         },
+         "gpt-4.1": {
+             KEY_INPUT_TOKENS: 3.00,
+             KEY_OUTPUT_TOKENS: 12.00,
+         },
+         "gpt-4.1-mini": {
+             KEY_INPUT_TOKENS: 0.8,
+             KEY_OUTPUT_TOKENS: 3.2,
+         },
+         "gpt-4.1-nano": {
+             KEY_INPUT_TOKENS: 0.2,
+             KEY_OUTPUT_TOKENS: 0.8,
+         },
+         "o4-mini": {
+             KEY_INPUT_TOKENS: 4.00,
+             KEY_OUTPUT_TOKENS: 16.00,
+         },
+         "mistral-small3.2:24b": {
+             KEY_INPUT_TOKENS: 0.01,
+             KEY_OUTPUT_TOKENS: 0.03,
+         },
+         "gpt-oss:120b": {
+             KEY_INPUT_TOKENS: 0.01,
+             KEY_OUTPUT_TOKENS: 0.03,
+         },
+         "gemma3:27b": {
+             KEY_INPUT_TOKENS: 0.01,
+             KEY_OUTPUT_TOKENS: 0.03,
+         },
+         "qwen3:32b-q8_0": {
+             KEY_INPUT_TOKENS: 0.01,
+             KEY_OUTPUT_TOKENS: 0.03,
+         },
+         "dummy-model": {
+             KEY_INPUT_TOKENS: 0.1,
+             KEY_OUTPUT_TOKENS: 1.0,
+         }
+     }
+
+     @classmethod
+     def estimate(cls, model_name: str, input_tokens: int, output_tokens: int) -> float:
+         if model_name not in cls.pricing_per_1_mln:
+             raise ValueError(f"Model '{model_name}' not supported for pricing estimation")
+
+         model_pricing = cls.pricing_per_1_mln[model_name]
+         input_cost = model_pricing[KEY_INPUT_TOKENS] * (input_tokens / 1_000_000)
+         output_cost = model_pricing[KEY_OUTPUT_TOKENS] * (output_tokens / 1_000_000)
+         return input_cost + output_cost
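A worked example of Pricing.estimate using the gpt-4.1-mini rates from the table above (all prices are per one million tokens); the token counts are illustrative:

    from toolchemy.ai.clients.pricing import Pricing

    # 10,000 input tokens:  0.8 * (10_000 / 1_000_000) = 0.0080 USD
    #  2,000 output tokens: 3.2 * (2_000 / 1_000_000)  = 0.0064 USD
    cost = Pricing.estimate("gpt-4.1-mini", input_tokens=10_000, output_tokens=2_000)
    assert abs(cost - 0.0144) < 1e-9

    # Models missing from pricing_per_1_mln raise ValueError rather than returning 0.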
toolchemy/ai/clients/whisper_client.py
@@ -0,0 +1,141 @@
+ import asyncio
+ import requests
+ import sys
+ import os
+ import subprocess
+ import tempfile
+ from wyoming.client import AsyncClient
+ from wyoming.audio import AudioChunk, AudioStart, AudioStop, AudioChunkConverter
+ from wyoming.asr import Transcript
+ from wyoming.ping import Ping
+ from wyoming.info import Describe
+ import wave
+
+ from toolchemy.utils.logger import get_logger
+
+
+ class WhisperClient:
+     def __init__(self, url: str):
+         self._logger = get_logger()
+         self._endpoint = url
+         self._whisper_client_wyoming = None
+         if self._endpoint.startswith("http"):
+             if not self._endpoint.endswith("transcribe"):
+                 if not self._endpoint.endswith("/"):
+                     self._endpoint += "/"
+                 self._endpoint += "transcribe"
+         elif self._endpoint.startswith("tcp"):
+             self._whisper_client_wyoming = AsyncClient.from_uri(self._endpoint)
+         else:
+             raise ValueError(f"Unknown protocol for the whisper server endpoint: '{self._endpoint}'")
+         self._logger.info(f"Whisper client initialized (endpoint: '{self._endpoint}')")
+
+     def transcribe(self, audio_path: str) -> str:
+         transcription = None
+
+         if self._endpoint.startswith("tcp"):
+             transcription = asyncio.run(self._transcribe_wyoming(audio_path))
+
+         if self._endpoint.startswith("http"):
+             transcription = self._transcribe_http(audio_path)
+
+         if transcription is None:
+             raise RuntimeError(f"Transcription failed...")
+
+         return transcription.strip()
+
+     def _transcribe_http(self, audio_path: str) -> str:
+         if not os.path.exists(audio_path):
+             raise ValueError(f"Error: File '{audio_path}' not found.")
+
+         with open(audio_path, "rb") as audio_file:
+             files = {"file": audio_file}
+
+             self._logger.info(f"Sending '{audio_path}' to Whisper server...")
+             response = requests.post(self._endpoint, files=files)
+
+         if response.status_code == 200:
+             result = response.json()
+             result_transcription = result.get("text")
+             self._logger.info(f"Transcription: '{result_transcription}'")
+             return result_transcription
+
+         err_msg = f"Error: Failed to transcribe. Status Code: {response.status_code}"
+         self._logger.error(err_msg)
+         raise RuntimeError(err_msg)
+
+     async def _transcribe_wyoming(self, audio_path: str, audio_rate: int = 16000, audio_width: int = 2,
+                                   audio_channels: int = 1, chunk_size: int = 1024) -> str:
+         wav_path = self._convert_to_wav(audio_path, audio_rate=audio_rate, audio_channels=audio_channels)
+
+         await self._whisper_client_wyoming.connect()
+         await self._whisper_client_wyoming.write_event(Ping(text="test").event())
+
+         await self._whisper_client_wyoming.write_event(Describe().event())
+
+         info_event = await self._whisper_client_wyoming.read_event()
+         self._logger.info(f"info event: {info_event}")
+
+         with wave.open(wav_path, 'rb') as wav:
+             assert wav.getframerate() == audio_rate
+             assert wav.getsampwidth() == audio_width
+             assert wav.getnchannels() == audio_channels
+
+             await self._whisper_client_wyoming.write_event(AudioStart(audio_rate, audio_width, audio_channels).event())
+
+             audio_bytes = wav.readframes(chunk_size)
+
+             converter = AudioChunkConverter(
+                 rate=audio_rate, width=audio_width, channels=audio_channels,
+             )
+
+             while audio_bytes:
+                 chunk = converter.convert(AudioChunk(audio_rate, audio_width, audio_channels, audio_bytes))
+                 await self._whisper_client_wyoming.write_event(chunk.event())
+                 audio_bytes = wav.readframes(chunk_size)
+
+         await self._whisper_client_wyoming.write_event(AudioStop().event())
+
+         while True:
+             event = await asyncio.wait_for(self._whisper_client_wyoming.read_event(), timeout=30)
+             if event is None:
+                 break
+             transcript = Transcript.from_event(event)
+             if transcript.text:
+                 self._logger.info(f"Transcription: {transcript.text}")
+                 break
+         await self._whisper_client_wyoming.disconnect()
+
+         return transcript.text
+
+     def _convert_to_wav(self, input_path: str, audio_rate: int, audio_channels: int) -> str:
+         temp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
+         temp_wav.close()
+
+         cmd = [
+             "ffmpeg",
+             "-y",
+             "-i", input_path,
+             "-ar", str(audio_rate),
+             "-ac", str(audio_channels),
+             "-f", "wav",
+             "-sample_fmt", "s16",
+             temp_wav.name
+         ]
+
+         subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+         return temp_wav.name
+
+
+ def main(argv: list):
+     if len(argv) < 2:
+         raise ValueError("Usage: python transcribe_audio.py <audio_file.mp3/wav>")
+
+     file_path = argv[1]
+     client = WhisperClient(url="tcp://hal:10300")
+     transcription = client.transcribe(file_path)
+     print(f"> '{transcription}'")
+
+
+ if __name__ == "__main__":
+     main(sys.argv)
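A hypothetical usage sketch: the transport is chosen from the URL scheme. http(s) endpoints get "/transcribe" appended when missing and receive a multipart POST; tcp endpoints speak the Wyoming protocol and need ffmpeg on PATH for the WAV conversion. Hosts, ports, and file names below are placeholders.

    from toolchemy.ai.clients.whisper_client import WhisperClient

    # HTTP backend: the endpoint becomes http://localhost:9000/transcribe
    http_client = WhisperClient(url="http://localhost:9000")
    print(http_client.transcribe("meeting.mp3"))

    # Wyoming backend over TCP, as in main() above
    tcp_client = WhisperClient(url="tcp://localhost:10300")
    print(tcp_client.transcribe("meeting.mp3"))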
toolchemy/ai/prompter.py
@@ -0,0 +1,124 @@
+ import logging
+ import mlflow
+ import sys
+ import subprocess
+ from abc import ABC, abstractmethod
+
+ from toolchemy.utils.cacher import Cacher, ICacher
+ from toolchemy.utils.logger import get_logger
+ from toolchemy.utils.locations import Locations
+
+
+ class IPrompter(ABC):
+     @abstractmethod
+     def render(self, name: str, version: str | dict[str, str | list[str]] | None = None, **variables) -> str:
+         pass
+
+     @abstractmethod
+     def template(self, name: str, version: str | dict[str, str | list[str]] | None = None) -> str:
+         pass
+
+     @abstractmethod
+     def create_template(self, name: str, template: str, overwrite: bool = False):
+         pass
+
+     @abstractmethod
+     def run_studio(self):
+         pass
+
+
+ def run_studio():
+     prompter = PrompterMLflow()
+     prompter.run_studio()
+
+
+ class PrompterBase(IPrompter, ABC):
+     def __init__(self, cacher: ICacher | None = None, no_cache: bool = False, log_level: int = logging.INFO):
+         self._logger = get_logger(level=log_level)
+         self._cacher = cacher or Cacher(disabled=no_cache)
+
+     def _prompt_version(self, name: str, version_mapping: str | dict[str, str | list[str]] | None = None) -> str | None:
+         if version_mapping is None:
+             return self._latest_value()
+         if isinstance(version_mapping, int):
+             return str(version_mapping)
+         if isinstance(version_mapping, str):
+             return version_mapping
+         if name not in version_mapping:
+             return None
+         if isinstance(version_mapping[name], str):
+             return version_mapping[name]
+         return version_mapping[name][0]
+
+     @abstractmethod
+     def _latest_value(self) -> str | None:
+         pass
+
+
+ class PrompterMLflow(PrompterBase):
+     DEFAULT_PROMPT_REGISTRY_NAME = "prompts_mlflow"
+
+     def __init__(self, registry_store_dir: str | None = None, cacher: ICacher | None = None, no_cache: bool = False, log_level: int = logging.INFO):
+         super().__init__(cacher=cacher, no_cache=no_cache, log_level=log_level)
+         locations = Locations()
+         if registry_store_dir is None:
+             registry_store_dir = locations.in_root(self.DEFAULT_PROMPT_REGISTRY_NAME)
+         registry_store_dir = locations.abs(registry_store_dir).rstrip("/")
+
+         self._registry_store_uri = f"sqlite:///{registry_store_dir}/registry.db"
+         self._tracking_uri = f"sqlite:///{registry_store_dir}/tracking.db"
+         mlflow.set_tracking_uri(self._tracking_uri)
+         mlflow.set_registry_uri(self._registry_store_uri)
+
+         self._logger.info(f"Prompter-MLflow initialized")
+         self._logger.info(f"> tracking store uri: {self._tracking_uri}")
+         self._logger.info(f"> registry store uri: {self._registry_store_uri}")
+
+     def render(self, name: str, version: str | dict[str, str | list[str]] | None = None, **variables) -> str:
+         prompt_uri = self._build_prompt_uri(name=name, version=version)
+
+         cache_key = self._cacher.create_cache_key(["prompt_render", prompt_uri], [variables])
+         if self._cacher.exists(cache_key):
+             return self._cacher.get(cache_key)
+
+         self._logger.debug(f"Rendering prompt: '{name}' (version: '{version}') -> prompt uri: '{prompt_uri}'")
+
+         prompt_template = mlflow.genai.load_prompt(prompt_uri)
+         prompt = prompt_template.format(**variables)
+
+         self._cacher.set(cache_key, prompt)
+
+         return prompt
+
+     def template(self, name: str, version: str | dict[str, str | list[str]] | None = None) -> str:
+         prompt_uri = self._build_prompt_uri(name=name, version=version)
+
+         cache_key = self._cacher.create_cache_key(["prompt_template", prompt_uri])
+         if self._cacher.exists(cache_key):
+             return self._cacher.get(cache_key)
+
+         self._logger.debug(f"Getting prompt template: '{name}' (version: '{version}') -> prompt uri: '{prompt_uri}'")
+
+         prompt_template = mlflow.genai.load_prompt(prompt_uri)
+
+         self._cacher.set(cache_key, prompt_template.template)
+
+         return prompt_template.template
+
+     def create_template(self, name: str, template: str, overwrite: bool = False):
+         if mlflow.genai.load_prompt(name_or_uri=name, allow_missing=True) and not overwrite:
+             return
+         mlflow.genai.register_prompt(name=name, template=template)
+
+     def _build_prompt_uri(self, name: str, version: str | int | None = None) -> str:
+         prompt_version = self._prompt_version(name, version_mapping=version)
+         prompt_uri = f"prompts:/{name}"
+         prompt_uri += f"@{prompt_version}" if isinstance(prompt_version, str) and not prompt_version.isdigit() else f"/{prompt_version}"
+         return prompt_uri
+
+     def run_studio(self):
+         command = [sys.executable, "-m", "mlflow", "ui", "--registry-store-uri", self._registry_store_uri, "--backend-store-uri", self._tracking_uri]
+         sys.exit(subprocess.call(command))
+
+     def _latest_value(self) -> str | None:
+         return "latest"
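A hypothetical usage sketch, assuming MLflow's double-brace template syntax (expanded by load_prompt(...).format(**variables)); the prompt name, template text, and variable are illustrative. By default both SQLite stores land under a prompts_mlflow/ directory resolved via Locations.

    from toolchemy.ai.prompter import PrompterMLflow

    prompter = PrompterMLflow()
    prompter.create_template("greeting", "Hello, {{name}}!")     # no-op if it exists and overwrite=False
    print(prompter.render("greeting", version="1", name="Ada"))  # numeric version -> prompts:/greeting/1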
toolchemy/ai/trackers/__init__.py
@@ -0,0 +1,5 @@
+ from .common import ITracker, InMemoryTracker
+ from .mlflow_tracker import MLFlowTracker
+ from .neptune_tracker import NeptuneAITracker
+
+ __all__ = ["ITracker", "InMemoryTracker", "MLFlowTracker", "NeptuneAITracker"]
toolchemy/ai/trackers/common.py
@@ -0,0 +1,216 @@
+ from abc import ABC, abstractmethod
+ import statistics
+ from typing import Dict, Any
+
+ from toolchemy.utils.logger import get_logger
+
+
+ class ITracker(ABC):
+     @property
+     @abstractmethod
+     def experiment_name(self) -> str:
+         pass
+
+     @property
+     @abstractmethod
+     def run_name(self) -> str:
+         pass
+
+     @abstractmethod
+     def start_run(
+         self, run_id: str = None,
+         run_name: str = None,
+         parent_run_id: str = None,
+         user_specified_tags: Dict[str, str] = None
+     ):
+         pass
+
+     @abstractmethod
+     def end_run(self):
+         pass
+
+     @abstractmethod
+     def log(self, name: str, value: Any):
+         pass
+
+     @abstractmethod
+     def log_param(self, name: str, value: Any):
+         pass
+
+     @abstractmethod
+     def log_params(self, params: Dict[str, Any]):
+         pass
+
+     @abstractmethod
+     def log_metric(self, name: str, value: float, step: int | None = None, metric_metadata: dict | None = None):
+         pass
+
+     @abstractmethod
+     def log_text(self, name: str, value: str):
+         pass
+
+     @abstractmethod
+     def log_metrics(self, metrics: Dict[str, float | list], step: int | None = None):
+         pass
+
+     @abstractmethod
+     def log_artifact(self, artifact_path: str, save_dir: str = None):
+         pass
+
+     @abstractmethod
+     def log_figure(self, figure, save_path: str):
+         """
+         Parameters:
+             figure (matplotlib.figure.Figure or plotly.graph_objects.Figure): plot figure
+             save_path (str): the run-relative artifact file path in posixpath format to which
+                 the figure is saved (e.g. "dir/file.png").
+         """
+         pass
+
+     @abstractmethod
+     def set_run_tag(self, name: str, value: str | int | float):
+         pass
+
+     @abstractmethod
+     def set_experiment_tag(self, name: str, value: str | int | float):
+         pass
+
+     @abstractmethod
+     def disable(self):
+         pass
+
+     @abstractmethod
+     def get_data(self) -> dict:
+         pass
+
+
+ class TrackerBase(ITracker, ABC):
+     def __init__(self, experiment_name: str, with_artifact_logging: bool = True, disabled: bool = False):
+         self._disabled = disabled
+         self._logger = get_logger()
+         self._experiment_name = experiment_name
+         self._artifact_logging = with_artifact_logging
+         self._metrics = {}
+         self._params = {}
+         self._tags = {
+             "experiment": {},
+             "runs": {},
+         }
+
+     @property
+     def experiment_name(self) -> str:
+         return self._experiment_name
+
+     def get_max_metric_value(self, name: str) -> float:
+         return max(self._metrics[name], key=lambda el: el['value'])
+
+     def get_min_metric_value(self, name: str) -> float:
+         return min(self._metrics[name], key=lambda el: el['value'])
+
+     def get_avg_metric_value(self, name: str) -> float:
+         metric_values = [m['value'] for m in self._metrics[name]]
+         return statistics.mean(metric_values)
+
+     def get_data(self) -> dict:
+         return {
+             "metrics": self._metrics.copy(),
+             "params": self._params.copy(),
+             "tags": self._tags.copy(),
+         }
+
+     def _store_param(self, name: str, value: Any):
+         if self._disabled:
+             raise RuntimeError(f"Disabled trackers cannot store params!")
+
+         self._params[name] = value
+
+     def _store_tag(self, name: str, value: str | int | float, run_name: str | None = None):
+         if run_name is None:
+             self._tags["experiment"][name] = value
+             return
+         if run_name not in self._tags["runs"]:
+             self._tags["runs"][run_name] = {}
+         self._tags["runs"][run_name][name] = value
+
+     def _store_metric(self, name: str, value: float, metric_metadata: dict | None = None) -> float:
+         if self._disabled:
+             raise RuntimeError(f"Disabled trackers cannot store metrics!")
+         if name not in self._metrics:
+             self._metrics[name] = []
+
+         if isinstance(value, dict):
+             new_entry = value
+         else:
+             new_entry = {
+                 'value': value
+             }
+
+         if metric_metadata:
+             new_entry.update(metric_metadata)
+
+         self._metrics[name] += [new_entry]
+
+         return new_entry['value']
+
+     def disable(self):
+         self._disabled = True
+
+
+ class InMemoryTracker(TrackerBase):
+     def __init__(self, experiment_name: str = "dummy", disabled: bool = False):
+         super().__init__(experiment_name=experiment_name, disabled=disabled)
+         self._run_name = None
+         self._reset()
+
+     @property
+     def run_name(self) -> str:
+         return self._run_name
+
+     def start_run(
+         self, run_id: str = None,
+         run_name: str = None,
+         parent_run_id: str = None,
+         user_specified_tags: Dict[str, str] = None
+     ):
+         self._run_name = run_name
+
+     def _reset(self):
+         self._run_name = None
+         self._data = {}
+         self._params = {}
+         self._metrics = {}
+
+     def end_run(self):
+         self._reset()
+
+     def log(self, name: str, value: Any):
+         self._data[name] = value
+
+     def log_param(self, name: str, value: Any):
+         self._params[name] = value
+
+     def log_params(self, params: Dict[str, Any]):
+         for name, value in params.items():
+             self.log_param(name, value)
+
+     def log_text(self, name: str, value: str):
+         self.log(name, value)
+
+     def log_metric(self, name: str, value: float, step: int | None = None, metric_metadata: dict | None = None):
+         self._metrics[name] = value
+
+     def log_metrics(self, metrics: Dict[str, float | list], step: int | None = None):
+         for name, value in metrics.items():
+             self.log_metric(name, value, step)
+
+     def log_artifact(self, artifact_path: str, save_dir: str = None):
+         raise NotImplementedError()
+
+     def log_figure(self, figure, save_path: str):
+         raise NotImplementedError()
+
+     def set_run_tag(self, name: str, value: str | int | float):
+         self._store_tag(name, value, run_name=self.run_name)
+
+     def set_experiment_tag(self, name: str, value: str | int | float):
+         self._store_tag(name, value)
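A minimal sketch of the in-memory tracker defined above; the names and values are illustrative. Note that end_run() calls _reset(), clearing run-local params and metrics, while tags accumulate on TrackerBase under "experiment" and per-run keys.

    from toolchemy.ai.trackers import InMemoryTracker

    tracker = InMemoryTracker(experiment_name="demo")
    tracker.start_run(run_name="run-1")
    tracker.log_param("learning_rate", 1e-3)
    tracker.log_metrics({"loss": 0.42, "accuracy": 0.91})
    tracker.set_run_tag("stage", "dev")             # stored under tags["runs"]["run-1"]
    tracker.set_experiment_tag("owner", "ml-team")  # stored under tags["experiment"]
    tracker.end_run()                               # resets run-local params/metrics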