kiln-ai 0.20.1__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic. Click here for more details.

Files changed (117)
  1. kiln_ai/adapters/__init__.py +6 -0
  2. kiln_ai/adapters/adapter_registry.py +43 -226
  3. kiln_ai/adapters/chunkers/__init__.py +13 -0
  4. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  5. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  6. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  7. kiln_ai/adapters/chunkers/helpers.py +23 -0
  8. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  9. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  10. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  11. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  12. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  13. kiln_ai/adapters/embedding/__init__.py +0 -0
  14. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  15. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  16. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  17. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  18. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  19. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  20. kiln_ai/adapters/eval/eval_runner.py +6 -2
  21. kiln_ai/adapters/eval/test_base_eval.py +1 -3
  22. kiln_ai/adapters/eval/test_g_eval.py +1 -1
  23. kiln_ai/adapters/extractors/__init__.py +18 -0
  24. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  25. kiln_ai/adapters/extractors/encoding.py +20 -0
  26. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  27. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  28. kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
  29. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  30. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  31. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  32. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  33. kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
  34. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  35. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
  36. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  37. kiln_ai/adapters/ml_embedding_model_list.py +192 -0
  38. kiln_ai/adapters/ml_model_list.py +382 -4
  39. kiln_ai/adapters/model_adapters/litellm_adapter.py +7 -69
  40. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +1 -1
  41. kiln_ai/adapters/model_adapters/test_structured_output.py +3 -1
  42. kiln_ai/adapters/ollama_tools.py +69 -12
  43. kiln_ai/adapters/provider_tools.py +190 -46
  44. kiln_ai/adapters/rag/deduplication.py +49 -0
  45. kiln_ai/adapters/rag/progress.py +252 -0
  46. kiln_ai/adapters/rag/rag_runners.py +844 -0
  47. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  48. kiln_ai/adapters/rag/test_progress.py +785 -0
  49. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  50. kiln_ai/adapters/remote_config.py +80 -8
  51. kiln_ai/adapters/test_adapter_registry.py +579 -86
  52. kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
  53. kiln_ai/adapters/test_ml_model_list.py +212 -0
  54. kiln_ai/adapters/test_ollama_tools.py +340 -1
  55. kiln_ai/adapters/test_prompt_builders.py +1 -1
  56. kiln_ai/adapters/test_provider_tools.py +199 -8
  57. kiln_ai/adapters/test_remote_config.py +551 -56
  58. kiln_ai/adapters/vector_store/__init__.py +1 -0
  59. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  60. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  61. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  62. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  63. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  64. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  65. kiln_ai/datamodel/__init__.py +16 -13
  66. kiln_ai/datamodel/basemodel.py +170 -1
  67. kiln_ai/datamodel/chunk.py +158 -0
  68. kiln_ai/datamodel/datamodel_enums.py +27 -0
  69. kiln_ai/datamodel/embedding.py +64 -0
  70. kiln_ai/datamodel/extraction.py +303 -0
  71. kiln_ai/datamodel/project.py +33 -1
  72. kiln_ai/datamodel/rag.py +79 -0
  73. kiln_ai/datamodel/test_attachment.py +649 -0
  74. kiln_ai/datamodel/test_basemodel.py +1 -1
  75. kiln_ai/datamodel/test_chunk_models.py +317 -0
  76. kiln_ai/datamodel/test_dataset_split.py +1 -1
  77. kiln_ai/datamodel/test_embedding_models.py +448 -0
  78. kiln_ai/datamodel/test_eval_model.py +6 -6
  79. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  80. kiln_ai/datamodel/test_extraction_model.py +470 -0
  81. kiln_ai/datamodel/test_rag.py +641 -0
  82. kiln_ai/datamodel/test_tool_id.py +81 -0
  83. kiln_ai/datamodel/test_vector_store.py +320 -0
  84. kiln_ai/datamodel/tool_id.py +22 -0
  85. kiln_ai/datamodel/vector_store.py +141 -0
  86. kiln_ai/tools/mcp_session_manager.py +4 -1
  87. kiln_ai/tools/rag_tools.py +157 -0
  88. kiln_ai/tools/test_mcp_session_manager.py +1 -1
  89. kiln_ai/tools/test_rag_tools.py +848 -0
  90. kiln_ai/tools/test_tool_registry.py +91 -2
  91. kiln_ai/tools/tool_registry.py +21 -0
  92. kiln_ai/utils/__init__.py +3 -0
  93. kiln_ai/utils/async_job_runner.py +62 -17
  94. kiln_ai/utils/config.py +2 -2
  95. kiln_ai/utils/env.py +15 -0
  96. kiln_ai/utils/filesystem.py +14 -0
  97. kiln_ai/utils/filesystem_cache.py +60 -0
  98. kiln_ai/utils/litellm.py +94 -0
  99. kiln_ai/utils/lock.py +100 -0
  100. kiln_ai/utils/mime_type.py +38 -0
  101. kiln_ai/utils/pdf_utils.py +38 -0
  102. kiln_ai/utils/test_async_job_runner.py +151 -35
  103. kiln_ai/utils/test_env.py +142 -0
  104. kiln_ai/utils/test_filesystem_cache.py +316 -0
  105. kiln_ai/utils/test_litellm.py +206 -0
  106. kiln_ai/utils/test_lock.py +185 -0
  107. kiln_ai/utils/test_mime_type.py +66 -0
  108. kiln_ai/utils/test_pdf_utils.py +73 -0
  109. kiln_ai/utils/test_uuid.py +111 -0
  110. kiln_ai/utils/test_validation.py +524 -0
  111. kiln_ai/utils/uuid.py +9 -0
  112. kiln_ai/utils/validation.py +90 -0
  113. {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +7 -1
  114. kiln_ai-0.21.0.dist-info/RECORD +211 -0
  115. kiln_ai-0.20.1.dist-info/RECORD +0 -138
  116. {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
  117. {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -1,3 +1,4 @@
1
+ from pathlib import Path
1
2
  from unittest.mock import Mock
2
3
 
3
4
  import pytest
@@ -8,6 +9,7 @@ from kiln_ai.datamodel.task import Task
8
9
  from kiln_ai.datamodel.tool_id import (
9
10
  MCP_LOCAL_TOOL_ID_PREFIX,
10
11
  MCP_REMOTE_TOOL_ID_PREFIX,
12
+ RAG_TOOL_ID_PREFIX,
11
13
  KilnBuiltInToolId,
12
14
  _check_tool_id,
13
15
  mcp_server_and_tool_name_from_id,
@@ -143,7 +145,7 @@ class TestToolRegistry:
143
145
  tool_id = f"{MCP_LOCAL_TOOL_ID_PREFIX}test_server::test_tool"
144
146
  with pytest.raises(
145
147
  ValueError,
146
- match="Unable to resolve tool from id.*Requires a parent project/task",
148
+ match=r"Unable to resolve tool from id.*Requires a parent project/task",
147
149
  ):
148
150
  tool_from_id(tool_id, task=None)
149
151
 
@@ -181,6 +183,93 @@ class TestToolRegistry:
181
183
  ):
182
184
  tool_from_id(tool_id, task=mock_task)
183
185
 
186
def test_tool_from_id_rag_tool_success(self):
    """A RAG tool ID resolves its RagConfig from the project path and wraps it in a RagTool."""
    from unittest.mock import patch

    config_id = "test_rag_config"
    project_path = Path("/test/path")

    fake_project = Mock(spec=Project)
    fake_project.id = "test_project_id"
    fake_project.path = project_path

    fake_task = Mock(spec=Task)
    fake_task.parent_project.return_value = fake_project

    fake_config = Mock()
    fake_config.id = config_id
    fake_tool = Mock()

    with (
        patch("kiln_ai.tools.tool_registry.RagConfig") as rag_config_cls,
        patch("kiln_ai.tools.rag_tools.RagTool") as rag_tool_cls,
    ):
        rag_config_cls.from_id_and_parent_path.return_value = fake_config
        rag_tool_cls.return_value = fake_tool

        tool_id = f"{RAG_TOOL_ID_PREFIX}{config_id}"
        resolved = tool_from_id(tool_id, task=fake_task)

    # Mocks keep their recorded calls after the patches are undone.
    assert resolved == fake_tool
    rag_config_cls.from_id_and_parent_path.assert_called_once_with(
        config_id, project_path
    )
    rag_tool_cls.assert_called_once_with(tool_id, fake_config)


def test_tool_from_id_rag_tool_no_task(self):
    """Resolving a RAG tool ID without any task is rejected."""
    with pytest.raises(
        ValueError,
        match=r"Unable to resolve tool from id.*Requires a parent project/task",
    ):
        tool_from_id(f"{RAG_TOOL_ID_PREFIX}test_rag_config", task=None)


def test_tool_from_id_rag_tool_no_project(self):
    """A task whose parent_project() returns None cannot resolve a RAG tool."""
    orphan_task = Mock(spec=Task)
    orphan_task.parent_project.return_value = None

    with pytest.raises(
        ValueError,
        match=r"Unable to resolve tool from id.*Requires a parent project/task",
    ):
        tool_from_id(f"{RAG_TOOL_ID_PREFIX}test_rag_config", task=orphan_task)


def test_tool_from_id_rag_config_not_found(self):
    """A RAG tool ID whose config lookup returns None raises a descriptive ValueError."""
    from unittest.mock import patch

    fake_project = Mock(spec=Project)
    fake_project.id = "test_project_id"
    fake_project.path = Path("/test/path")

    fake_task = Mock(spec=Task)
    fake_task.parent_project.return_value = fake_project

    with patch("kiln_ai.tools.tool_registry.RagConfig") as rag_config_cls:
        # Simulate a missing config: the lookup finds nothing.
        rag_config_cls.from_id_and_parent_path.return_value = None

        with pytest.raises(
            ValueError,
            match=r"RAG config not found: missing_rag_config in project test_project_id for tool",
        ):
            tool_from_id(f"{RAG_TOOL_ID_PREFIX}missing_rag_config", task=fake_task)
272
+
184
273
  def test_all_built_in_tools_are_registered(self):
185
274
  """Test that all KilnBuiltInToolId enum members are handled by the registry."""
186
275
  for tool_id in KilnBuiltInToolId:
@@ -406,7 +495,7 @@ class TestToolRegistry:
406
495
 
407
496
  with pytest.raises(
408
497
  ValueError,
409
- match="Unable to resolve tool from id.*Requires a parent project/task",
498
+ match=r"Unable to resolve tool from id.*Requires a parent project/task",
410
499
  ):
411
500
  tool_from_id(mcp_tool_id, task=None)
412
501
 
@@ -1,9 +1,12 @@
1
+ from kiln_ai.datamodel.rag import RagConfig
1
2
  from kiln_ai.datamodel.task import Task
2
3
  from kiln_ai.datamodel.tool_id import (
3
4
  MCP_LOCAL_TOOL_ID_PREFIX,
4
5
  MCP_REMOTE_TOOL_ID_PREFIX,
6
+ RAG_TOOL_ID_PREFIX,
5
7
  KilnBuiltInToolId,
6
8
  mcp_server_and_tool_name_from_id,
9
+ rag_config_id_from_id,
7
10
  )
8
11
  from kiln_ai.tools.base_tool import KilnToolInterface
9
12
  from kiln_ai.tools.built_in_tools.math_tools import (
@@ -60,5 +63,23 @@ def tool_from_id(tool_id: str, task: Task | None = None) -> KilnToolInterface:
60
63
  )
61
64
 
62
65
  return MCPServerTool(server, tool_name)
66
+ elif tool_id.startswith(RAG_TOOL_ID_PREFIX):
67
+ project = task.parent_project() if task is not None else None
68
+ if project is None:
69
+ raise ValueError(
70
+ f"Unable to resolve tool from id: {tool_id}. Requires a parent project/task."
71
+ )
72
+
73
+ rag_config_id = rag_config_id_from_id(tool_id)
74
+ rag_config = RagConfig.from_id_and_parent_path(rag_config_id, project.path)
75
+ if rag_config is None:
76
+ raise ValueError(
77
+ f"RAG config not found: {rag_config_id} in project {project.id} for tool {tool_id}"
78
+ )
79
+
80
+ # Lazy import to avoid circular dependency
81
+ from kiln_ai.tools.rag_tools import RagTool
82
+
83
+ return RagTool(tool_id, rag_config)
63
84
 
64
85
  raise ValueError(f"Tool ID {tool_id} not found in tool registry")
kiln_ai/utils/__init__.py CHANGED
@@ -5,8 +5,11 @@ Misc utilities used in the kiln_ai library.
5
5
  """
6
6
 
7
7
  from . import config, formatting
8
+ from .lock import AsyncLockManager, shared_async_lock_manager
8
9
 
9
10
  __all__ = [
11
+ "AsyncLockManager",
10
12
  "config",
11
13
  "formatting",
14
+ "shared_async_lock_manager",
12
15
  ]
@@ -1,7 +1,7 @@
1
1
  import asyncio
2
2
  import logging
3
3
  from dataclasses import dataclass
4
- from typing import AsyncGenerator, Awaitable, Callable, List, TypeVar
4
+ from typing import AsyncGenerator, Awaitable, Callable, Generic, List, TypeVar
5
5
 
6
6
  logger = logging.getLogger(__name__)
7
7
 
@@ -15,29 +15,66 @@ class Progress:
15
15
  errors: int
16
16
 
17
17
 
18
- class AsyncJobRunner:
19
- def __init__(self, concurrency: int = 1):
18
class AsyncJobRunnerObserver(Generic[T]):
    """
    Hook interface notified by AsyncJobRunner as individual jobs progress.

    Every hook is a no-op by default, so subclasses override only the
    notifications they care about.
    """

    async def on_error(self, job: T, error: Exception):
        """Receive the job and the exception when run_job_fn raises for it."""

    async def on_success(self, job: T):
        """Receive the job after run_job_fn returned a truthy result for it."""

    async def on_job_start(self, job: T):
        """Receive the job just before run_job_fn is invoked on it."""
36
+
37
+
38
+ class AsyncJobRunner(Generic[T]):
39
+ def __init__(
40
+ self,
41
+ jobs: List[T],
42
+ run_job_fn: Callable[[T], Awaitable[bool]],
43
+ concurrency: int = 1,
44
+ observers: List[AsyncJobRunnerObserver[T]] | None = None,
45
+ ):
20
46
  if concurrency < 1:
21
47
  raise ValueError("concurrency must be ≥ 1")
22
48
  self.concurrency = concurrency
49
+ self.jobs = jobs
50
+ self.run_job_fn = run_job_fn
51
+ self.observers = observers or []
23
52
 
24
- async def run(
25
- self,
26
- jobs: List[T],
27
- run_job: Callable[[T], Awaitable[bool]],
28
- ) -> AsyncGenerator[Progress, None]:
53
+ async def notify_error(self, job: T, error: Exception):
54
+ for observer in self.observers:
55
+ await observer.on_error(job, error)
56
+
57
+ async def notify_success(self, job: T):
58
+ for observer in self.observers:
59
+ await observer.on_success(job)
60
+
61
+ async def notify_job_start(self, job: T):
62
+ for observer in self.observers:
63
+ await observer.on_job_start(job)
64
+
65
+ async def run(self) -> AsyncGenerator[Progress, None]:
29
66
  """
30
67
  Runs the jobs with parallel workers and yields progress updates.
31
68
  """
32
69
  complete = 0
33
70
  errors = 0
34
- total = len(jobs)
71
+ total = len(self.jobs)
35
72
 
36
73
  # Send initial status
37
74
  yield Progress(complete=complete, total=total, errors=errors)
38
75
 
39
76
  worker_queue: asyncio.Queue[T] = asyncio.Queue()
40
- for job in jobs:
77
+ for job in self.jobs:
41
78
  worker_queue.put_nowait(job)
42
79
 
43
80
  # simple status queue to return progress. True=success, False=error
@@ -46,7 +83,7 @@ class AsyncJobRunner:
46
83
  workers = []
47
84
  for _ in range(self.concurrency):
48
85
  task = asyncio.create_task(
49
- self._run_worker(worker_queue, status_queue, run_job),
86
+ self._run_worker(worker_queue, status_queue, self.run_job_fn),
50
87
  )
51
88
  workers.append(task)
52
89
 
@@ -64,7 +101,11 @@ class AsyncJobRunner:
64
101
  else:
65
102
  errors += 1
66
103
 
67
- yield Progress(complete=complete, total=total, errors=errors)
104
+ yield Progress(
105
+ complete=complete,
106
+ total=total,
107
+ errors=errors,
108
+ )
68
109
  except asyncio.TimeoutError:
69
110
  # Timeout is expected, just continue to recheck worker status
70
111
  # Don't love this but beats sentinels for reliability
@@ -82,7 +123,7 @@ class AsyncJobRunner:
82
123
  self,
83
124
  worker_queue: asyncio.Queue[T],
84
125
  status_queue: asyncio.Queue[bool],
85
- run_job: Callable[[T], Awaitable[bool]],
126
+ run_job_fn: Callable[[T], Awaitable[bool]],
86
127
  ):
87
128
  while True:
88
129
  try:
@@ -92,13 +133,17 @@ class AsyncJobRunner:
92
133
  break
93
134
 
94
135
  try:
95
- success = await run_job(job)
96
- except Exception:
136
+ await self.notify_job_start(job)
137
+ result = await run_job_fn(job)
138
+ if result:
139
+ await self.notify_success(job)
140
+ except Exception as e:
97
141
  logger.error("Job failed to complete", exc_info=True)
98
- success = False
142
+ await self.notify_error(job, e)
143
+ result = False
99
144
 
100
145
  try:
101
- await status_queue.put(success)
146
+ await status_queue.put(result)
102
147
  except Exception:
103
148
  logger.error("Failed to enqueue status for job", exc_info=True)
104
149
  finally:
kiln_ai/utils/config.py CHANGED
@@ -221,14 +221,14 @@ class Config:
221
221
  raise AttributeError(f"Config has no attribute '{name}'")
222
222
 
223
223
  @classmethod
224
- def settings_dir(cls, create=True):
224
+ def settings_dir(cls, create=True) -> str:
225
225
  settings_dir = os.path.join(Path.home(), ".kiln_ai")
226
226
  if create and not os.path.exists(settings_dir):
227
227
  os.makedirs(settings_dir)
228
228
  return settings_dir
229
229
 
230
230
  @classmethod
231
- def settings_path(cls, create=True):
231
+ def settings_path(cls, create=True) -> str:
232
232
  settings_dir = cls.settings_dir(create)
233
233
  return os.path.join(settings_dir, "settings.yaml")
234
234
 
kiln_ai/utils/env.py ADDED
@@ -0,0 +1,15 @@
1
+ import os
2
+ from contextlib import contextmanager
3
+
4
+
5
@contextmanager
def temporary_env(var_name: str, value: str):
    """
    Set environment variable `var_name` to `value` for the duration of a `with` block.

    On exit (normal or via an exception), the previous value is restored. If the
    variable did not exist before, it is removed again.
    """
    previous = os.environ.get(var_name)
    os.environ[var_name] = value
    try:
        yield
    finally:
        if previous is not None:
            os.environ[var_name] = previous
        else:
            # The variable was absent before entering, so drop it again.
            os.environ.pop(var_name, None)
@@ -0,0 +1,14 @@
1
+ import os
2
+ import subprocess
3
+ import sys
4
+ from pathlib import Path
5
+
6
+
7
def open_folder(path: str | Path) -> None:
    """
    Open the directory containing `path` in the platform's file browser.

    This opens os.path.dirname(path), i.e. the parent of `path`, not `path`
    itself (callers presumably pass a file path — confirm before relying on it).

    Raises:
        subprocess.CalledProcessError: if `open` / `xdg-open` exits non-zero.
    """
    parent_dir = os.path.dirname(path)
    platform = sys.platform

    if platform.startswith("win"):
        os.startfile(parent_dir)  # type: ignore[attr-defined]
        return

    launcher = "open" if platform.startswith("darwin") else "xdg-open"
    subprocess.run([launcher, parent_dir], check=True)
@@ -0,0 +1,60 @@
1
+ import logging
2
+ import tempfile
3
+ from pathlib import Path
4
+
5
+ import anyio
6
+
7
+ from kiln_ai.datamodel.basemodel import name_validator
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
class FilesystemCache:
    """
    Byte cache that stores one file per key inside `cache_dir_path`.

    Keys double as file names, so every path is built through get_path(),
    which validates the key first.
    """

    def __init__(self, path: Path):
        # Directory holding the cache files; it is not created here.
        self.cache_dir_path = path

    def validate_key(self, key: str) -> None:
        """Raise if `key` fails name_validator(min_length=1, max_length=120)."""
        # throws if invalid
        name_validator(min_length=1, max_length=120)(key)

    def get_path(self, key: str) -> Path:
        """Return the cache file path for `key`, validating the key first."""
        self.validate_key(key)
        return self.cache_dir_path / key

    async def get(self, key: str) -> bytes | None:
        """
        Return the cached bytes for `key`, or None on a miss or an unreadable file.

        An invalid key is NOT treated as a miss: get_path() validates it and the
        validation error propagates to the caller.
        """
        # Build (and validate) the path once instead of once per use.
        path = self.get_path(key)
        if not path.exists():
            return None

        # we don't want to raise because of internal cache corruption issues
        try:
            return await anyio.Path(path).read_bytes()
        except Exception:
            logger.error(f"Error reading file {path}", exc_info=True)
            return None

    async def set(self, key: str, value: bytes) -> Path:
        """Write `value` to the file for `key` (validating the key) and return that path."""
        # get_path() already validates, so no separate validate_key() call is needed.
        path = self.get_path(key)
        logger.debug(f"Caching {key} at {path}")
        await anyio.Path(path).write_bytes(value)
        return path
43
+
44
+
45
class TemporaryFilesystemCache:
    """
    Owns a FilesystemCache rooted in a freshly created temporary directory.

    The directory comes from tempfile.mkdtemp and is not removed by this class.
    """

    # Lazily created class-level instance backing shared(); no locking is done.
    _shared_instance = None

    def __init__(self):
        temp_dir = tempfile.mkdtemp(prefix="kiln_cache_")
        self._cache_temp_dir = temp_dir
        self.filesystem_cache = FilesystemCache(path=Path(temp_dir))
        logger.debug(f"Created temporary filesystem cache directory: {temp_dir}")

    @classmethod
    def shared(cls) -> FilesystemCache:
        """Return the FilesystemCache of the shared instance, creating it on first use."""
        instance = cls._shared_instance
        if instance is None:
            instance = cls()
            cls._shared_instance = instance
        return instance.filesystem_cache
@@ -0,0 +1,94 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kiln_ai.adapters.ml_embedding_model_list import KilnEmbeddingModelProvider
4
+ from kiln_ai.adapters.ml_model_list import KilnModelProvider
5
+ from kiln_ai.datamodel.datamodel_enums import ModelProviderName
6
+ from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
7
+
8
+
9
@dataclass
class LitellmProviderInfo:
    """How a Kiln model provider maps onto litellm's provider/model naming."""

    # Provider prefix exactly as litellm spells it (e.g. "openrouter", "vertex_ai").
    provider_name: str
    # True for providers Kiln routes itself: custom models, Ollama, fine tunes and
    # custom registry models. These use "openai" as provider_name purely for the
    # request format.
    is_custom: bool
    # Model slug handed to litellm, formed as "<provider_name>/<model_id>".
    litellm_model_id: str
17
+
18
+
19
+ def get_litellm_provider_info(
20
+ model_provider: KilnEmbeddingModelProvider | KilnModelProvider,
21
+ ) -> LitellmProviderInfo:
22
+ """
23
+ Maps a Kiln model provider to a litellm provider.
24
+
25
+ Args:
26
+ model_provider: The model provider to get litellm provider info for
27
+
28
+ Returns:
29
+ LitellmProviderInfo containing the provider name and whether it's custom
30
+ """
31
+ if not model_provider.model_id:
32
+ raise ValueError("Model ID is required for OpenAI compatible models")
33
+
34
+ litellm_provider_name: str | None = None
35
+ is_custom = False
36
+ match model_provider.name:
37
+ case ModelProviderName.openrouter:
38
+ litellm_provider_name = "openrouter"
39
+ case ModelProviderName.openai:
40
+ litellm_provider_name = "openai"
41
+ case ModelProviderName.groq:
42
+ litellm_provider_name = "groq"
43
+ case ModelProviderName.anthropic:
44
+ litellm_provider_name = "anthropic"
45
+ case ModelProviderName.ollama:
46
+ # We don't let litellm use the Ollama API and muck with our requests. We use Ollama's OpenAI compatible API.
47
+ # This is because we're setting detailed features like response_format=json_schema and want lower level control.
48
+ is_custom = True
49
+ case ModelProviderName.docker_model_runner:
50
+ # Docker Model Runner uses OpenAI-compatible API, similar to Ollama
51
+ # We want direct control over the requests for features like response_format=json_schema
52
+ is_custom = True
53
+ case ModelProviderName.gemini_api:
54
+ litellm_provider_name = "gemini"
55
+ case ModelProviderName.fireworks_ai:
56
+ litellm_provider_name = "fireworks_ai"
57
+ case ModelProviderName.amazon_bedrock:
58
+ litellm_provider_name = "bedrock"
59
+ case ModelProviderName.azure_openai:
60
+ litellm_provider_name = "azure"
61
+ case ModelProviderName.huggingface:
62
+ litellm_provider_name = "huggingface"
63
+ case ModelProviderName.vertex:
64
+ litellm_provider_name = "vertex_ai"
65
+ case ModelProviderName.together_ai:
66
+ litellm_provider_name = "together_ai"
67
+ case ModelProviderName.cerebras:
68
+ litellm_provider_name = "cerebras"
69
+ case ModelProviderName.siliconflow_cn:
70
+ is_custom = True
71
+ case ModelProviderName.openai_compatible:
72
+ is_custom = True
73
+ case ModelProviderName.kiln_custom_registry:
74
+ is_custom = True
75
+ case ModelProviderName.kiln_fine_tune:
76
+ is_custom = True
77
+ case _:
78
+ raise_exhaustive_enum_error(model_provider.name)
79
+
80
+ if is_custom:
81
+ # Use openai as it's only used for format, not url
82
+ litellm_provider_name = "openai"
83
+
84
+ # Shouldn't be possible but keep type checker happy
85
+ if litellm_provider_name is None:
86
+ raise ValueError(
87
+ f"Provider name could not lookup valid litellm provider ID {model_provider.model_id}"
88
+ )
89
+
90
+ return LitellmProviderInfo(
91
+ provider_name=litellm_provider_name,
92
+ is_custom=is_custom,
93
+ litellm_model_id=f"{litellm_provider_name}/{model_provider.model_id}",
94
+ )
kiln_ai/utils/lock.py ADDED
@@ -0,0 +1,100 @@
1
+ import asyncio
2
+ from contextlib import asynccontextmanager
3
+ from dataclasses import dataclass, field
4
+ from typing import Dict, Hashable
5
+
6
+
7
+ @dataclass
8
+ class _Entry:
9
+ lock: asyncio.Lock = field(default_factory=asyncio.Lock)
10
+ waiters: int = 0 # tasks waiting to acquire
11
+ holders: int = 0 # 0 or 1 for a mutex
12
+
13
+
14
+ class AsyncLockManager:
15
+ """
16
+ A per-key asyncio lock manager that automatically cleans up locks when they're no longer needed.
17
+
18
+ Usage:
19
+ locks = AsyncLockManager()
20
+
21
+ async with locks.acquire("user:123"):
22
+ # critical section for "user:123"
23
+ ...
24
+
25
+ The manager removes a key when there are no holders and no waiters.
26
+ """
27
+
28
+ def __init__(self) -> None:
29
+ # Protects the _locks dict and bookkeeping counters.
30
+ self._mu = asyncio.Lock()
31
+ self._locks: Dict[Hashable, _Entry] = {}
32
+
33
+ @asynccontextmanager
34
+ async def acquire(self, key: Hashable, *, timeout: float | None = None):
35
+ """
36
+ Acquire the lock for `key` as an async context manager.
37
+
38
+ - `timeout`: optional seconds to wait; raises TimeoutError on expiry.
39
+ """
40
+ # Phase 1: register as a waiter and get/create the entry (under manager mutex).
41
+ async with self._mu:
42
+ entry = self._locks.get(key)
43
+ if entry is None:
44
+ entry = self._locks[key] = _Entry()
45
+ entry.waiters += 1
46
+
47
+ # Phase 2: wait on the per-key lock (outside manager mutex).
48
+ try:
49
+ if timeout is None:
50
+ await entry.lock.acquire()
51
+ else:
52
+ # Manual timeout to keep compatibility across Python versions.
53
+ await asyncio.wait_for(entry.lock.acquire(), timeout=timeout)
54
+
55
+ # Phase 3: update counters: became a holder.
56
+ async with self._mu:
57
+ entry.waiters -= 1
58
+ entry.holders += 1
59
+
60
+ try:
61
+ yield # critical section
62
+ finally:
63
+ # Phase 4: release holder and maybe cleanup.
64
+ entry.lock.release()
65
+ async with self._mu:
66
+ entry.holders -= 1
67
+ # Remove the entry if fully idle.
68
+ if entry.waiters == 0 and entry.holders == 0:
69
+ # Double-check we still point to same object (paranoia/race safety).
70
+ if self._locks.get(key) is entry:
71
+ del self._locks[key]
72
+
73
+ except asyncio.TimeoutError:
74
+ # Timed out while waiting; undo waiter count and maybe cleanup.
75
+ async with self._mu:
76
+ entry.waiters -= 1
77
+ if entry.waiters == 0 and entry.holders == 0:
78
+ if self._locks.get(key) is entry:
79
+ del self._locks[key]
80
+ raise
81
+ except asyncio.CancelledError:
82
+ # Cancelled while waiting; same cleanup as timeout.
83
+ async with self._mu:
84
+ entry.waiters -= 1
85
+ if entry.waiters == 0 and entry.holders == 0:
86
+ if self._locks.get(key) is entry:
87
+ del self._locks[key]
88
+ raise
89
+
90
+ # Optional: expose a snapshot for metrics/debugging
91
+ async def snapshot(self) -> Dict[Hashable, dict]:
92
+ async with self._mu:
93
+ return {
94
+ k: {"waiters": e.waiters, "holders": e.holders}
95
+ for k, e in self._locks.items()
96
+ }
97
+
98
+
99
+ # callers should use this global instance instead of creating their own
100
+ shared_async_lock_manager = AsyncLockManager()
@@ -0,0 +1,38 @@
1
+ import mimetypes
2
+
3
+
4
+ def guess_mime_type(filename: str) -> str | None:
5
+ filename_normalized = filename.lower()
6
+
7
+ # we override the mimetypes.guess_type for some common cases
8
+ # because it does not handle them correctly
9
+ if filename_normalized.endswith(".mov"):
10
+ return "video/quicktime"
11
+ elif filename_normalized.endswith(".mp3"):
12
+ return "audio/mpeg"
13
+ elif filename_normalized.endswith(".wav"):
14
+ return "audio/wav"
15
+ elif filename_normalized.endswith(".mp4"):
16
+ return "video/mp4"
17
+
18
+ mime_type, _ = mimetypes.guess_type(filename_normalized)
19
+ return mime_type
20
+
21
+
22
+ def guess_extension(mime_type: str) -> str | None:
23
+ mapping = {
24
+ "application/pdf": ".pdf",
25
+ "image/png": ".png",
26
+ "video/mp4": ".mp4",
27
+ "audio/ogg": ".ogg",
28
+ "text/markdown": ".md",
29
+ "text/plain": ".txt",
30
+ "text/html": ".html",
31
+ "text/csv": ".csv",
32
+ "image/jpeg": ".jpeg",
33
+ "image/jpg": ".jpeg",
34
+ "audio/mpeg": ".mp3",
35
+ "audio/wav": ".wav",
36
+ "video/quicktime": ".mov",
37
+ }
38
+ return mapping.get(mime_type)