kiln-ai 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +8 -2
- kiln_ai/adapters/adapter_registry.py +43 -208
- kiln_ai/adapters/chat/chat_formatter.py +8 -12
- kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/docker_model_runner_tools.py +119 -0
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/base_eval.py +2 -2
- kiln_ai/adapters/eval/eval_runner.py +9 -3
- kiln_ai/adapters/eval/g_eval.py +2 -2
- kiln_ai/adapters/eval/test_base_eval.py +2 -4
- kiln_ai/adapters/eval/test_g_eval.py +4 -5
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
- kiln_ai/adapters/fine_tune/__init__.py +1 -1
- kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +192 -0
- kiln_ai/adapters/ml_model_list.py +761 -37
- kiln_ai/adapters/model_adapters/base_adapter.py +51 -21
- kiln_ai/adapters/model_adapters/litellm_adapter.py +380 -138
- kiln_ai/adapters/model_adapters/test_base_adapter.py +193 -17
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -2
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
- kiln_ai/adapters/model_adapters/test_structured_output.py +113 -5
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/parsers/__init__.py +1 -1
- kiln_ai/adapters/provider_tools.py +205 -47
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/repair/test_repair_task.py +12 -9
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +657 -85
- kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
- kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
- kiln_ai/adapters/test_ml_model_list.py +251 -1
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_adaptors.py +13 -6
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +254 -8
- kiln_ai/adapters/test_remote_config.py +651 -58
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +39 -34
- kiln_ai/datamodel/basemodel.py +170 -1
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +28 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/eval.py +1 -1
- kiln_ai/datamodel/external_tool_server.py +298 -0
- kiln_ai/datamodel/extraction.py +303 -0
- kiln_ai/datamodel/json_schema.py +25 -10
- kiln_ai/datamodel/project.py +40 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/registry.py +0 -15
- kiln_ai/datamodel/run_config.py +62 -0
- kiln_ai/datamodel/task.py +2 -77
- kiln_ai/datamodel/task_output.py +6 -1
- kiln_ai/datamodel/task_run.py +41 -0
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +4 -4
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_example_models.py +175 -0
- kiln_ai/datamodel/test_external_tool_server.py +691 -0
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +470 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_registry.py +8 -3
- kiln_ai/datamodel/test_task.py +15 -47
- kiln_ai/datamodel/test_tool_id.py +320 -0
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +105 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/__init__.py +8 -0
- kiln_ai/tools/base_tool.py +82 -0
- kiln_ai/tools/built_in_tools/__init__.py +13 -0
- kiln_ai/tools/built_in_tools/math_tools.py +124 -0
- kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
- kiln_ai/tools/mcp_server_tool.py +95 -0
- kiln_ai/tools/mcp_session_manager.py +246 -0
- kiln_ai/tools/rag_tools.py +157 -0
- kiln_ai/tools/test_base_tools.py +199 -0
- kiln_ai/tools/test_mcp_server_tool.py +457 -0
- kiln_ai/tools/test_mcp_session_manager.py +1585 -0
- kiln_ai/tools/test_rag_tools.py +848 -0
- kiln_ai/tools/test_tool_registry.py +562 -0
- kiln_ai/tools/tool_registry.py +85 -0
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +24 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +94 -0
- kiln_ai/utils/pdf_utils.py +38 -0
- kiln_ai/utils/project_utils.py +17 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_config.py +138 -1
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +131 -0
- kiln_ai/utils/test_pdf_utils.py +73 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +12 -5
- kiln_ai-0.21.0.dist-info/RECORD +211 -0
- kiln_ai-0.19.0.dist-info/RECORD +0 -115
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/utils/async_job_runner.py
CHANGED

@@ -1,7 +1,7 @@
 import asyncio
 import logging
 from dataclasses import dataclass
-from typing import AsyncGenerator, Awaitable, Callable, List, TypeVar
+from typing import AsyncGenerator, Awaitable, Callable, Generic, List, TypeVar
 
 logger = logging.getLogger(__name__)
 
@@ -15,29 +15,66 @@ class Progress:
     errors: int
 
 
-class AsyncJobRunner:
-    def __init__(self, concurrency: int = 1):
+class AsyncJobRunnerObserver(Generic[T]):
+    async def on_error(self, job: T, error: Exception):
+        """
+        Called when a job raises an unhandled exception.
+        """
+        pass
+
+    async def on_success(self, job: T):
+        """
+        Called when a job completes successfully.
+        """
+        pass
+
+    async def on_job_start(self, job: T):
+        """
+        Called when a job starts.
+        """
+        pass
+
+
+class AsyncJobRunner(Generic[T]):
+    def __init__(
+        self,
+        jobs: List[T],
+        run_job_fn: Callable[[T], Awaitable[bool]],
+        concurrency: int = 1,
+        observers: List[AsyncJobRunnerObserver[T]] | None = None,
+    ):
         if concurrency < 1:
             raise ValueError("concurrency must be ≥ 1")
         self.concurrency = concurrency
+        self.jobs = jobs
+        self.run_job_fn = run_job_fn
+        self.observers = observers or []
 
-    async def run(
-        self,
-        jobs: List[T],
-        run_job_fn: Callable[[T], Awaitable[bool]],
-    ) -> AsyncGenerator[Progress, None]:
+    async def notify_error(self, job: T, error: Exception):
+        for observer in self.observers:
+            await observer.on_error(job, error)
+
+    async def notify_success(self, job: T):
+        for observer in self.observers:
+            await observer.on_success(job)
+
+    async def notify_job_start(self, job: T):
+        for observer in self.observers:
+            await observer.on_job_start(job)
+
+    async def run(self) -> AsyncGenerator[Progress, None]:
         """
         Runs the jobs with parallel workers and yields progress updates.
         """
         complete = 0
         errors = 0
-        total = len(jobs)
+        total = len(self.jobs)
 
         # Send initial status
        yield Progress(complete=complete, total=total, errors=errors)
 
         worker_queue: asyncio.Queue[T] = asyncio.Queue()
-        for job in jobs:
+        for job in self.jobs:
             worker_queue.put_nowait(job)
 
         # simple status queue to return progress. True=success, False=error
@@ -46,7 +83,7 @@ class AsyncJobRunner:
         workers = []
         for _ in range(self.concurrency):
             task = asyncio.create_task(
-                self._run_worker(worker_queue, status_queue, run_job_fn),
+                self._run_worker(worker_queue, status_queue, self.run_job_fn),
             )
             workers.append(task)
 
@@ -64,7 +101,11 @@ class AsyncJobRunner:
                 else:
                     errors += 1
 
-                yield Progress(complete=complete, total=total, errors=errors)
+                yield Progress(
+                    complete=complete,
+                    total=total,
+                    errors=errors,
+                )
             except asyncio.TimeoutError:
                 # Timeout is expected, just continue to recheck worker status
                 # Don't love this but beats sentinels for reliability
@@ -82,7 +123,7 @@ class AsyncJobRunner:
         self,
         worker_queue: asyncio.Queue[T],
         status_queue: asyncio.Queue[bool],
-        run_job_fn,
+        run_job_fn: Callable[[T], Awaitable[bool]],
     ):
         while True:
             try:
@@ -92,13 +133,17 @@ class AsyncJobRunner:
                 break
 
             try:
-                success = await run_job_fn(job)
-            except Exception:
+                await self.notify_job_start(job)
+                result = await run_job_fn(job)
+                if result:
+                    await self.notify_success(job)
+            except Exception as e:
                 logger.error("Job failed to complete", exc_info=True)
-                success = False
+                await self.notify_error(job, e)
+                result = False
 
             try:
-                await status_queue.put(success)
+                await status_queue.put(result)
             except Exception:
                 logger.error("Failed to enqueue status for job", exc_info=True)
             finally:
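
The new API moves jobs, the job function, and optional observers into the constructor, replacing the old run(jobs, run_job_fn) call. A minimal sketch of driving it, using only names introduced in the diff above (PrintObserver and run_job are illustrative):

import asyncio

from kiln_ai.utils.async_job_runner import AsyncJobRunner, AsyncJobRunnerObserver


class PrintObserver(AsyncJobRunnerObserver[int]):
    async def on_success(self, job: int):
        print(f"job {job} succeeded")

    async def on_error(self, job: int, error: Exception):
        print(f"job {job} failed: {error}")


async def run_job(job: int) -> bool:
    await asyncio.sleep(0.01)  # stand-in for real work
    return job % 2 == 0  # pretend odd jobs fail


async def main():
    runner = AsyncJobRunner(
        jobs=list(range(10)),
        run_job_fn=run_job,
        concurrency=4,
        observers=[PrintObserver()],
    )
    # run() is an async generator that yields Progress snapshots as jobs finish.
    async for progress in runner.run():
        print(f"{progress.complete}/{progress.total}, errors={progress.errors}")


asyncio.run(main())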
kiln_ai/utils/config.py
CHANGED

@@ -6,6 +6,9 @@ from typing import Any, Callable, Dict, List, Optional
 
 import yaml
 
+# Configuration keys
+MCP_SECRETS_KEY = "mcp_secrets"
+
 
 class ConfigProperty:
     def __init__(
@@ -54,6 +57,10 @@ class Config:
                 str,
                 env_var="OLLAMA_BASE_URL",
             ),
+            "docker_model_runner_base_url": ConfigProperty(
+                str,
+                env_var="DOCKER_MODEL_RUNNER_BASE_URL",
+            ),
             "bedrock_access_key": ConfigProperty(
                 str,
                 env_var="AWS_ACCESS_KEY_ID",
@@ -147,6 +154,21 @@ class Config:
                 env_var="CEREBRAS_API_KEY",
                 sensitive=True,
             ),
+            "enable_demo_tools": ConfigProperty(
+                bool,
+                env_var="ENABLE_DEMO_TOOLS",
+                default=False,
+            ),
+            # Allow the user to set the path to lookup MCP server commands, like npx.
+            "custom_mcp_path": ConfigProperty(
+                str,
+                env_var="CUSTOM_MCP_PATH",
+            ),
+            # Allow the user to set secrets for MCP servers, the key is mcp_server_id::key_name
+            MCP_SECRETS_KEY: ConfigProperty(
+                dict[str, str],
+                sensitive=True,
+            ),
         }
         self._lock = threading.Lock()
         self._settings = self.load_settings()
@@ -199,14 +221,14 @@ class Config:
         raise AttributeError(f"Config has no attribute '{name}'")
 
     @classmethod
-    def settings_dir(cls, create=True):
+    def settings_dir(cls, create=True) -> str:
         settings_dir = os.path.join(Path.home(), ".kiln_ai")
         if create and not os.path.exists(settings_dir):
             os.makedirs(settings_dir)
         return settings_dir
 
     @classmethod
-    def settings_path(cls, create=True):
+    def settings_path(cls, create=True) -> str:
         settings_dir = cls.settings_dir(create)
         return os.path.join(settings_dir, "settings.yaml")
 
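
A minimal sketch of the new properties in use. It assumes Config.shared() exposes each property as an attribute (consistent with the __getattr__/AttributeError context above and with project_utils.py below); the env-var value is illustrative:

import os

from kiln_ai.utils.config import MCP_SECRETS_KEY, Config

# Illustrative value; any env-backed property can be overridden this way.
os.environ["DOCKER_MODEL_RUNNER_BASE_URL"] = "http://localhost:12434"

config = Config.shared()
print(config.docker_model_runner_base_url)  # "http://localhost:12434", via the env var
print(config.enable_demo_tools)  # False unless ENABLE_DEMO_TOOLS is set

# mcp_secrets is a sensitive dict keyed "mcp_server_id::key_name" per the comment above.
secrets = getattr(config, MCP_SECRETS_KEY, None) or {}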
kiln_ai/utils/env.py
ADDED

@@ -0,0 +1,15 @@
+import os
+from contextlib import contextmanager
+
+
+@contextmanager
+def temporary_env(var_name: str, value: str):
+    old_value = os.environ.get(var_name)
+    os.environ[var_name] = value
+    try:
+        yield
+    finally:
+        if old_value is None:
+            os.environ.pop(var_name, None)  # remove if it did not exist before
+        else:
+            os.environ[var_name] = old_value
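
Usage sketch for the new context manager (the variable name is hypothetical):

import os

from kiln_ai.utils.env import temporary_env

os.environ.pop("KILN_DEMO_VAR", None)  # hypothetical variable name
with temporary_env("KILN_DEMO_VAR", "on"):
    assert os.environ["KILN_DEMO_VAR"] == "on"
# The variable did not exist before, so it is removed again on exit.
assert "KILN_DEMO_VAR" not in os.environ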
kiln_ai/utils/filesystem.py
ADDED

@@ -0,0 +1,14 @@
+import os
+import subprocess
+import sys
+from pathlib import Path
+
+
+def open_folder(path: str | Path) -> None:
+    log_dir = os.path.dirname(path)
+    if sys.platform.startswith("darwin"):
+        subprocess.run(["open", log_dir], check=True)
+    elif sys.platform.startswith("win"):
+        os.startfile(log_dir)  # type: ignore[attr-defined]
+    else:
+        subprocess.run(["xdg-open", log_dir], check=True)
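
One behavior worth noting from the code above: open_folder takes os.path.dirname of its argument, so a file path reveals its containing folder, while a directory path opens the directory's parent. The path below is hypothetical:

from kiln_ai.utils.filesystem import open_folder

# Opens /tmp/kiln_logs in the platform file manager, not the file itself.
open_folder("/tmp/kiln_logs/app.log")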
kiln_ai/utils/filesystem_cache.py
ADDED

@@ -0,0 +1,60 @@
+import logging
+import tempfile
+from pathlib import Path
+
+import anyio
+
+from kiln_ai.datamodel.basemodel import name_validator
+
+logger = logging.getLogger(__name__)
+
+
+class FilesystemCache:
+    def __init__(self, path: Path):
+        self.cache_dir_path = path
+
+    def validate_key(self, key: str) -> None:
+        # throws if invalid
+        name_validator(min_length=1, max_length=120)(key)
+
+    def get_path(self, key: str) -> Path:
+        self.validate_key(key)
+        return self.cache_dir_path / key
+
+    async def get(self, key: str) -> bytes | None:
+        # check if the file exists - don't need to validate the key
+        # worst case we just return None
+        if not self.get_path(key).exists():
+            return None
+
+        # we don't want to raise because of internal cache corruption issues
+        try:
+            return await anyio.Path(self.get_path(key)).read_bytes()
+        except Exception:
+            logger.error(f"Error reading file {self.get_path(key)}", exc_info=True)
+            return None
+
+    async def set(self, key: str, value: bytes) -> Path:
+        logger.debug(f"Caching {key} at {self.get_path(key)}")
+        self.validate_key(key)
+        path = self.get_path(key)
+        await anyio.Path(path).write_bytes(value)
+        return path
+
+
+class TemporaryFilesystemCache:
+    _shared_instance = None
+
+    def __init__(self):
+        self._cache_temp_dir = tempfile.mkdtemp(prefix="kiln_cache_")
+        self.filesystem_cache = FilesystemCache(path=Path(self._cache_temp_dir))
+
+        logger.debug(
+            f"Created temporary filesystem cache directory: {self._cache_temp_dir}"
+        )
+
+    @classmethod
+    def shared(cls) -> FilesystemCache:
+        if cls._shared_instance is None:
+            cls._shared_instance = cls()
+        return cls._shared_instance.filesystem_cache
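
A minimal sketch of the cache, assuming the key "report_1" passes name_validator (keys are validated to 1-120 characters):

import asyncio

from kiln_ai.utils.filesystem_cache import TemporaryFilesystemCache


async def main() -> None:
    cache = TemporaryFilesystemCache.shared()  # process-wide cache in a temp dir
    path = await cache.set("report_1", b"hello")
    print(path)  # <temp dir>/report_1
    print(await cache.get("report_1"))  # b'hello'
    print(await cache.get("missing"))  # None for missing entries rather than raising


asyncio.run(main())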
kiln_ai/utils/litellm.py
ADDED

@@ -0,0 +1,94 @@
+from dataclasses import dataclass
+
+from kiln_ai.adapters.ml_embedding_model_list import KilnEmbeddingModelProvider
+from kiln_ai.adapters.ml_model_list import KilnModelProvider
+from kiln_ai.datamodel.datamodel_enums import ModelProviderName
+from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
+
+
+@dataclass
+class LitellmProviderInfo:
+    # The name of the provider, as it appears in litellm
+    provider_name: str
+    # Whether the provider is custom - e.g. custom models, ollama, fine tunes, and custom registry models
+    is_custom: bool
+    # The model ID slug to use in litellm
+    litellm_model_id: str
+
+
+def get_litellm_provider_info(
+    model_provider: KilnEmbeddingModelProvider | KilnModelProvider,
+) -> LitellmProviderInfo:
+    """
+    Maps a Kiln model provider to a litellm provider.
+
+    Args:
+        model_provider: The model provider to get litellm provider info for
+
+    Returns:
+        LitellmProviderInfo containing the provider name and whether it's custom
+    """
+    if not model_provider.model_id:
+        raise ValueError("Model ID is required for OpenAI compatible models")
+
+    litellm_provider_name: str | None = None
+    is_custom = False
+    match model_provider.name:
+        case ModelProviderName.openrouter:
+            litellm_provider_name = "openrouter"
+        case ModelProviderName.openai:
+            litellm_provider_name = "openai"
+        case ModelProviderName.groq:
+            litellm_provider_name = "groq"
+        case ModelProviderName.anthropic:
+            litellm_provider_name = "anthropic"
+        case ModelProviderName.ollama:
+            # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
+            # This is because we're setting detailed features like response_format=json_schema and want lower level control.
+            is_custom = True
+        case ModelProviderName.docker_model_runner:
+            # Docker Model Runner uses OpenAI-compatible API, similar to Ollama
+            # We want direct control over the requests for features like response_format=json_schema
+            is_custom = True
+        case ModelProviderName.gemini_api:
+            litellm_provider_name = "gemini"
+        case ModelProviderName.fireworks_ai:
+            litellm_provider_name = "fireworks_ai"
+        case ModelProviderName.amazon_bedrock:
+            litellm_provider_name = "bedrock"
+        case ModelProviderName.azure_openai:
+            litellm_provider_name = "azure"
+        case ModelProviderName.huggingface:
+            litellm_provider_name = "huggingface"
+        case ModelProviderName.vertex:
+            litellm_provider_name = "vertex_ai"
+        case ModelProviderName.together_ai:
+            litellm_provider_name = "together_ai"
+        case ModelProviderName.cerebras:
+            litellm_provider_name = "cerebras"
+        case ModelProviderName.siliconflow_cn:
+            is_custom = True
+        case ModelProviderName.openai_compatible:
+            is_custom = True
+        case ModelProviderName.kiln_custom_registry:
+            is_custom = True
+        case ModelProviderName.kiln_fine_tune:
+            is_custom = True
+        case _:
+            raise_exhaustive_enum_error(model_provider.name)
+
+    if is_custom:
+        # Use openai as it's only used for format, not url
+        litellm_provider_name = "openai"
+
+    # Shouldn't be possible but keep type checker happy
+    if litellm_provider_name is None:
+        raise ValueError(
+            f"Provider name could not lookup valid litellm provider ID {model_provider.model_id}"
+        )
+
+    return LitellmProviderInfo(
+        provider_name=litellm_provider_name,
+        is_custom=is_custom,
+        litellm_model_id=f"{litellm_provider_name}/{model_provider.model_id}",
+    )
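
A sketch of the mapping in use. It assumes a KilnModelProvider can be constructed from just a name and model_id; real entries in ml_model_list carry more metadata, and the model ID below is illustrative:

from kiln_ai.adapters.ml_model_list import KilnModelProvider
from kiln_ai.datamodel.datamodel_enums import ModelProviderName
from kiln_ai.utils.litellm import get_litellm_provider_info

provider = KilnModelProvider(
    name=ModelProviderName.groq,
    model_id="llama-3.3-70b-versatile",  # illustrative
)

info = get_litellm_provider_info(provider)
print(info.litellm_model_id)  # "groq/llama-3.3-70b-versatile"
print(info.is_custom)  # False; ollama/openai_compatible-style providers report True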
kiln_ai/utils/lock.py
ADDED

@@ -0,0 +1,100 @@
+import asyncio
+from contextlib import asynccontextmanager
+from dataclasses import dataclass, field
+from typing import Dict, Hashable
+
+
+@dataclass
+class _Entry:
+    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
+    waiters: int = 0  # tasks waiting to acquire
+    holders: int = 0  # 0 or 1 for a mutex
+
+
+class AsyncLockManager:
+    """
+    A per-key asyncio lock manager that automatically cleans up locks when they're no longer needed.
+
+    Usage:
+        locks = AsyncLockManager()
+
+        async with locks.acquire("user:123"):
+            # critical section for "user:123"
+            ...
+
+    The manager removes a key when there are no holders and no waiters.
+    """
+
+    def __init__(self) -> None:
+        # Protects the _locks dict and bookkeeping counters.
+        self._mu = asyncio.Lock()
+        self._locks: Dict[Hashable, _Entry] = {}
+
+    @asynccontextmanager
+    async def acquire(self, key: Hashable, *, timeout: float | None = None):
+        """
+        Acquire the lock for `key` as an async context manager.
+
+        - `timeout`: optional seconds to wait; raises TimeoutError on expiry.
+        """
+        # Phase 1: register as a waiter and get/create the entry (under manager mutex).
+        async with self._mu:
+            entry = self._locks.get(key)
+            if entry is None:
+                entry = self._locks[key] = _Entry()
+            entry.waiters += 1
+
+        # Phase 2: wait on the per-key lock (outside manager mutex).
+        try:
+            if timeout is None:
+                await entry.lock.acquire()
+            else:
+                # Manual timeout to keep compatibility across Python versions.
+                await asyncio.wait_for(entry.lock.acquire(), timeout=timeout)
+
+            # Phase 3: update counters: became a holder.
+            async with self._mu:
+                entry.waiters -= 1
+                entry.holders += 1
+
+            try:
+                yield  # critical section
+            finally:
+                # Phase 4: release holder and maybe cleanup.
+                entry.lock.release()
+                async with self._mu:
+                    entry.holders -= 1
+                    # Remove the entry if fully idle.
+                    if entry.waiters == 0 and entry.holders == 0:
+                        # Double-check we still point to same object (paranoia/race safety).
+                        if self._locks.get(key) is entry:
+                            del self._locks[key]
+
+        except asyncio.TimeoutError:
+            # Timed out while waiting; undo waiter count and maybe cleanup.
+            async with self._mu:
+                entry.waiters -= 1
+                if entry.waiters == 0 and entry.holders == 0:
+                    if self._locks.get(key) is entry:
+                        del self._locks[key]
+            raise
+        except asyncio.CancelledError:
+            # Cancelled while waiting; same cleanup as timeout.
+            async with self._mu:
+                entry.waiters -= 1
+                if entry.waiters == 0 and entry.holders == 0:
+                    if self._locks.get(key) is entry:
+                        del self._locks[key]
+            raise
+
+    # Optional: expose a snapshot for metrics/debugging
+    async def snapshot(self) -> Dict[Hashable, dict]:
+        async with self._mu:
+            return {
+                k: {"waiters": e.waiters, "holders": e.holders}
+                for k, e in self._locks.items()
+            }
+
+
+# callers should use this global instance instead of creating their own
+shared_async_lock_manager = AsyncLockManager()
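
A minimal sketch of the shared manager: work on the same key is serialized while different keys proceed concurrently, and idle keys are removed so snapshot() ends up empty:

import asyncio

from kiln_ai.utils.lock import shared_async_lock_manager


async def update_user(user_id: str) -> None:
    # Same key => serialized; different keys run in parallel.
    async with shared_async_lock_manager.acquire(f"user:{user_id}", timeout=5.0):
        await asyncio.sleep(0.1)  # stand-in for the critical section


async def main() -> None:
    await asyncio.gather(*(update_user("123") for _ in range(3)))
    print(await shared_async_lock_manager.snapshot())  # {} once all keys are idle


asyncio.run(main())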
kiln_ai/utils/mime_type.py
ADDED

@@ -0,0 +1,38 @@
+import mimetypes
+
+
+def guess_mime_type(filename: str) -> str | None:
+    filename_normalized = filename.lower()
+
+    # we override the mimetypes.guess_type for some common cases
+    # because it does not handle them correctly
+    if filename_normalized.endswith(".mov"):
+        return "video/quicktime"
+    elif filename_normalized.endswith(".mp3"):
+        return "audio/mpeg"
+    elif filename_normalized.endswith(".wav"):
+        return "audio/wav"
+    elif filename_normalized.endswith(".mp4"):
+        return "video/mp4"
+
+    mime_type, _ = mimetypes.guess_type(filename_normalized)
+    return mime_type
+
+
+def guess_extension(mime_type: str) -> str | None:
+    mapping = {
+        "application/pdf": ".pdf",
+        "image/png": ".png",
+        "video/mp4": ".mp4",
+        "audio/ogg": ".ogg",
+        "text/markdown": ".md",
+        "text/plain": ".txt",
+        "text/html": ".html",
+        "text/csv": ".csv",
+        "image/jpeg": ".jpeg",
+        "image/jpg": ".jpeg",
+        "audio/mpeg": ".mp3",
+        "audio/wav": ".wav",
+        "video/quicktime": ".mov",
+    }
+    return mapping.get(mime_type)
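
A quick sketch of both helpers; note that guess_extension consults only the fixed table above, so unlisted types return None:

from kiln_ai.utils.mime_type import guess_extension, guess_mime_type

print(guess_mime_type("Clip.MOV"))  # "video/quicktime", via the manual override
print(guess_mime_type("notes.txt"))  # "text/plain", falling through to mimetypes
print(guess_extension("image/jpg"))  # ".jpeg": the non-standard type is mapped explicitly
print(guess_extension("application/zip"))  # None: not in the table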
kiln_ai/utils/open_ai_types.py
ADDED

@@ -0,0 +1,94 @@
+"""
+Wrapper for OpenAI types to make them compatible with Pydantic.
+
+Pydantic doesn't support Iterable[T] well, so we use List[T] instead for tool_calls:
+https://github.com/pydantic/pydantic/issues/9541
+
+Otherwise we are using OpenAI SDK types directly.
+"""
+
+from typing import (
+    Iterable,
+    List,
+    Literal,
+    Optional,
+    TypeAlias,
+    Union,
+)
+
+from openai.types.chat import (
+    ChatCompletionDeveloperMessageParam,
+    ChatCompletionFunctionMessageParam,
+    ChatCompletionMessageToolCallParam,
+    ChatCompletionSystemMessageParam,
+    ChatCompletionToolMessageParam,
+    ChatCompletionUserMessageParam,
+)
+from openai.types.chat.chat_completion_assistant_message_param import (
+    Audio,
+    ContentArrayOfContentPart,
+    FunctionCall,
+)
+from typing_extensions import Required, TypedDict
+
+
+class ChatCompletionAssistantMessageParamWrapper(TypedDict, total=False):
+    """
+    Almost exact copy of ChatCompletionAssistantMessageParam, but two changes.
+
+    First change: List[T] instead of Iterable[T] for tool_calls. Addresses pydantic issue:
+    https://github.com/pydantic/pydantic/issues/9541
+
+    Second change: Add reasoning_content to the message. A LiteLLM property for reasoning data.
+    """
+
+    role: Required[Literal["assistant"]]
+    """The role of the messages author, in this case `assistant`."""
+
+    audio: Optional[Audio]
+    """Data about a previous audio response from the model.
+
+    [Learn more](https://platform.openai.com/docs/guides/audio).
+    """
+
+    content: Union[str, Iterable[ContentArrayOfContentPart], None]
+    """The contents of the assistant message.
+
+    Required unless `tool_calls` or `function_call` is specified.
+    """
+
+    reasoning_content: Optional[str]
+    """The reasoning content of the assistant message.
+
+    A LiteLLM property for reasoning data: https://docs.litellm.ai/docs/reasoning_content
+    """
+
+    function_call: Optional[FunctionCall]
+    """Deprecated and replaced by `tool_calls`.
+
+    The name and arguments of a function that should be called, as generated by the
+    model.
+    """
+
+    name: str
+    """An optional name for the participant.
+
+    Provides the model information to differentiate between participants of the same
+    role.
+    """
+
+    refusal: Optional[str]
+    """The refusal message by the assistant."""
+
+    tool_calls: List[ChatCompletionMessageToolCallParam]
+    """The tool calls generated by the model, such as function calls."""
+
+
+ChatCompletionMessageParam: TypeAlias = Union[
+    ChatCompletionDeveloperMessageParam,
+    ChatCompletionSystemMessageParam,
+    ChatCompletionUserMessageParam,
+    ChatCompletionAssistantMessageParamWrapper,
+    ChatCompletionToolMessageParam,
+    ChatCompletionFunctionMessageParam,
+]
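
Since TypedDicts are plain dicts at runtime, the wrapper only widens the static schema; a sketch of constructing one with the two additions the docstring calls out:

from kiln_ai.utils.open_ai_types import ChatCompletionAssistantMessageParamWrapper

message: ChatCompletionAssistantMessageParamWrapper = {
    "role": "assistant",
    "content": "The answer is 4.",
    "reasoning_content": "2 + 2 = 4",  # the LiteLLM-only field the stock type lacks
    "tool_calls": [],  # List rather than Iterable, so pydantic can validate it
}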
kiln_ai/utils/pdf_utils.py
ADDED

@@ -0,0 +1,38 @@
+"""
+Utilities for working with PDF files.
+"""
+
+import asyncio
+import tempfile
+from contextlib import asynccontextmanager
+from pathlib import Path
+from typing import AsyncGenerator
+
+from pypdf import PdfReader, PdfWriter
+
+
+@asynccontextmanager
+async def split_pdf_into_pages(pdf_path: Path) -> AsyncGenerator[list[Path], None]:
+    with tempfile.TemporaryDirectory(prefix="kiln_pdf_pages_") as temp_dir:
+        page_paths = []
+
+        with open(pdf_path, "rb") as file:
+            # Reader init can be heavy; offload to thread
+            pdf_reader = await asyncio.to_thread(PdfReader, file)
+
+            for page_num in range(len(pdf_reader.pages)):
+                await asyncio.sleep(0)
+                pdf_writer = PdfWriter()
+                pdf_writer.add_page(pdf_reader.pages[page_num])
+
+                # Create temporary file for this page
+                page_filename = f"page_{page_num + 1}.pdf"
+                page_path = Path(temp_dir) / page_filename
+
+                with open(page_path, "wb") as page_file:
+                    # Writing/compression can be expensive; offload to thread
+                    await asyncio.to_thread(pdf_writer.write, page_file)
+
+                page_paths.append(page_path)
+
+        yield page_paths
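
Usage sketch; the input filename is a placeholder, and the per-page files live in a temp directory that is deleted when the context exits, so copy them out if they must outlive it:

import asyncio
from pathlib import Path

from kiln_ai.utils.pdf_utils import split_pdf_into_pages


async def main() -> None:
    async with split_pdf_into_pages(Path("input.pdf")) as page_paths:
        for page_path in page_paths:
            print(page_path.name)  # page_1.pdf, page_2.pdf, ...


asyncio.run(main())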
kiln_ai/utils/project_utils.py
ADDED

@@ -0,0 +1,17 @@
+from kiln_ai.datamodel.project import Project
+from kiln_ai.utils.config import Config
+
+
+def project_from_id(project_id: str) -> Project | None:
+    project_paths = Config.shared().projects
+    if project_paths is not None:
+        for project_path in project_paths:
+            try:
+                project = Project.load_from_file(project_path)
+                if project.id == project_id:
+                    return project
+            except Exception:
+                # deleted files are possible continue with the rest
+                continue
+
+    return None
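
Usage sketch: the lookup scans every path in Config.shared().projects and skips entries that fail to load (e.g. deleted project files). The ID below is a placeholder:

from kiln_ai.utils.project_utils import project_from_id

project = project_from_id("my-project-id")
if project is None:
    print("no registered project has that ID")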