kiln-ai 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/__init__.py +8 -2
- kiln_ai/adapters/adapter_registry.py +43 -208
- kiln_ai/adapters/chat/chat_formatter.py +8 -12
- kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/docker_model_runner_tools.py +119 -0
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/base_eval.py +2 -2
- kiln_ai/adapters/eval/eval_runner.py +9 -3
- kiln_ai/adapters/eval/g_eval.py +2 -2
- kiln_ai/adapters/eval/test_base_eval.py +2 -4
- kiln_ai/adapters/eval/test_g_eval.py +4 -5
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
- kiln_ai/adapters/fine_tune/__init__.py +1 -1
- kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +192 -0
- kiln_ai/adapters/ml_model_list.py +761 -37
- kiln_ai/adapters/model_adapters/base_adapter.py +51 -21
- kiln_ai/adapters/model_adapters/litellm_adapter.py +380 -138
- kiln_ai/adapters/model_adapters/test_base_adapter.py +193 -17
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -2
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
- kiln_ai/adapters/model_adapters/test_structured_output.py +113 -5
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/parsers/__init__.py +1 -1
- kiln_ai/adapters/provider_tools.py +205 -47
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/repair/test_repair_task.py +12 -9
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +657 -85
- kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
- kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
- kiln_ai/adapters/test_ml_model_list.py +251 -1
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_adaptors.py +13 -6
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +254 -8
- kiln_ai/adapters/test_remote_config.py +651 -58
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +39 -34
- kiln_ai/datamodel/basemodel.py +170 -1
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +28 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/eval.py +1 -1
- kiln_ai/datamodel/external_tool_server.py +298 -0
- kiln_ai/datamodel/extraction.py +303 -0
- kiln_ai/datamodel/json_schema.py +25 -10
- kiln_ai/datamodel/project.py +40 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/registry.py +0 -15
- kiln_ai/datamodel/run_config.py +62 -0
- kiln_ai/datamodel/task.py +2 -77
- kiln_ai/datamodel/task_output.py +6 -1
- kiln_ai/datamodel/task_run.py +41 -0
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +4 -4
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_example_models.py +175 -0
- kiln_ai/datamodel/test_external_tool_server.py +691 -0
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +470 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_registry.py +8 -3
- kiln_ai/datamodel/test_task.py +15 -47
- kiln_ai/datamodel/test_tool_id.py +320 -0
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +105 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/__init__.py +8 -0
- kiln_ai/tools/base_tool.py +82 -0
- kiln_ai/tools/built_in_tools/__init__.py +13 -0
- kiln_ai/tools/built_in_tools/math_tools.py +124 -0
- kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
- kiln_ai/tools/mcp_server_tool.py +95 -0
- kiln_ai/tools/mcp_session_manager.py +246 -0
- kiln_ai/tools/rag_tools.py +157 -0
- kiln_ai/tools/test_base_tools.py +199 -0
- kiln_ai/tools/test_mcp_server_tool.py +457 -0
- kiln_ai/tools/test_mcp_session_manager.py +1585 -0
- kiln_ai/tools/test_rag_tools.py +848 -0
- kiln_ai/tools/test_tool_registry.py +562 -0
- kiln_ai/tools/tool_registry.py +85 -0
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +24 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +94 -0
- kiln_ai/utils/pdf_utils.py +38 -0
- kiln_ai/utils/project_utils.py +17 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_config.py +138 -1
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +131 -0
- kiln_ai/utils/test_pdf_utils.py +73 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +12 -5
- kiln_ai-0.21.0.dist-info/RECORD +211 -0
- kiln_ai-0.19.0.dist-info/RECORD +0 -115
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/utils/test_filesystem_cache.py (new file)
@@ -0,0 +1,316 @@
import tempfile
from pathlib import Path
from unittest.mock import AsyncMock, patch

import pytest

from kiln_ai.utils.filesystem_cache import FilesystemCache, TemporaryFilesystemCache


class TestFilesystemCache:
    @pytest.fixture
    def temp_dir(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            yield Path(tmp_dir)

    @pytest.fixture
    def cache(self, temp_dir):
        return FilesystemCache(temp_dir)

    def test_init(self, temp_dir):
        cache = FilesystemCache(temp_dir)
        assert cache.cache_dir_path == temp_dir
        assert cache.validate_key is not None

    def test_get_path_valid_key(self, cache):
        key = "test_file"
        expected_path = cache.cache_dir_path / key
        assert cache.get_path(key) == expected_path

    def test_get_path_invalid_key_empty(self, cache):
        with pytest.raises(ValueError):
            cache.get_path("")

    def test_get_path_invalid_key_too_long(self, cache):
        long_key = "x" * 121  # exceeds max_length=120
        with pytest.raises(ValueError):
            cache.get_path(long_key)

    def test_get_path_invalid_key_special_chars(self, cache):
        with pytest.raises(ValueError):
            cache.get_path("invalid/key")

    async def test_get_nonexistent_file(self, cache):
        result = await cache.get("nonexistent")
        assert result is None

    async def test_get_existing_file(self, cache, temp_dir):
        key = "test"
        content = b"Hello, World!"
        file_path = temp_dir / key
        file_path.write_bytes(content)

        result = await cache.get(key)
        assert result == content

    async def test_get_file_read_error(self, cache, temp_dir):
        key = "test"
        file_path = temp_dir / key
        file_path.write_bytes(b"test")

        with patch(
            "anyio.Path.read_bytes",
            new_callable=AsyncMock,
            side_effect=IOError("Read error"),
        ):
            result = await cache.get(key)
            assert result is None

    async def test_set_valid_data(self, cache, temp_dir):
        key = "test"
        content = b"Hello, World!"

        result_path = await cache.set(key, content)
        expected_path = temp_dir / key

        assert result_path == expected_path
        assert expected_path.exists()
        assert expected_path.read_bytes() == content

    async def test_set_invalid_key_empty(self, cache):
        with pytest.raises(ValueError):
            await cache.set("", b"content")

    async def test_set_invalid_key_too_long(self, cache):
        long_key = "x" * 121
        with pytest.raises(ValueError):
            await cache.set(long_key, b"content")

    async def test_set_invalid_key_special_chars(self, cache):
        with pytest.raises(ValueError):
            await cache.set("invalid/key", b"content")

    async def test_set_overwrites_existing(self, cache, temp_dir):
        key = "test"
        original_content = b"Original content"
        new_content = b"New content"

        # Set original content
        await cache.set(key, original_content)
        assert (temp_dir / key).read_bytes() == original_content

        # Overwrite with new content
        await cache.set(key, new_content)
        assert (temp_dir / key).read_bytes() == new_content

    async def test_set_creates_parent_directory(self, cache, temp_dir):
        # Test that set method creates parent directories if they don't exist
        # Note: This test demonstrates a limitation - the current implementation
        # doesn't support directory paths due to name_validator restrictions
        key = "subdir_test"  # Using underscore instead of slash
        content = b"Test content"

        result_path = await cache.set(key, content)
        expected_path = temp_dir / key

        assert result_path == expected_path
        assert expected_path.exists()
        assert expected_path.read_bytes() == content

    async def test_set_with_nested_directories(self, cache, temp_dir):
        # Test that set method creates deeply nested directories
        # Note: This test demonstrates a limitation - the current implementation
        # doesn't support directory paths due to name_validator restrictions
        key = "level1_level2_level3_test"  # Using underscores instead of slashes
        content = b"Deeply nested content"

        result_path = await cache.set(key, content)
        expected_path = temp_dir / key

        assert result_path == expected_path
        assert expected_path.exists()
        assert expected_path.read_bytes() == content

    async def test_directory_paths_not_supported(self, cache):
        # Test that demonstrates the current limitation - directory paths are not supported
        # due to name_validator forbidding forward slashes
        with pytest.raises(ValueError, match="Name is invalid"):
            await cache.set("subdir/test", b"content")

        with pytest.raises(ValueError, match="Name is invalid"):
            await cache.get("subdir/test")

    async def test_roundtrip_get_set(self, cache):
        key = "roundtrip"
        content = b"Roundtrip test content"

        # Set content
        await cache.set(key, content)

        # Get content back
        retrieved = await cache.get(key)
        assert retrieved == content

    async def test_multiple_files(self, cache, temp_dir):
        files = {
            "file1": b"Content 1",
            "file2": b"Content 2",
            "subdir_file3": b"Content 3",  # Using underscore instead of slash
        }

        # Set all files
        for key, content in files.items():
            await cache.set(key, content)

        # Verify all files exist and have correct content
        for key, expected_content in files.items():
            actual_content = await cache.get(key)
            assert actual_content == expected_content

    async def test_empty_bytes(self, cache):
        key = "empty"
        content = b""

        await cache.set(key, content)
        retrieved = await cache.get(key)
        assert retrieved == content

    async def test_large_content(self, cache):
        key = "large"
        content = b"x" * 10000  # 10KB of data

        await cache.set(key, content)
        retrieved = await cache.get(key)
        assert retrieved == content

    @pytest.mark.parametrize(
        "unicode_text",
        [
            "Simple ASCII text",
            "中文文本测试",
            "Mixed 中文 and English text",
            "Emojis: 🎉🚀💻🔥",
            "Complex: 你好世界! Hello 世界! 🌍 This is 测试 with 中文, emojis 🚀, and English.",
            "Special chars: ñáéíóú àèìòù çüöä",
            "Math symbols: ∑∆π∫∞±≤≥≠",
            "Currency: €£¥$₹₽",
        ],
    )
    async def test_unicode_text_retrieval_integrity(self, cache, unicode_text):
        # Test that Unicode text is not corrupted during storage and retrieval
        key = "unicode_integrity"
        content = unicode_text.encode("utf-8")

        # Store the text
        await cache.set(key, content)

        # Retrieve and verify
        retrieved = await cache.get(key)
        assert retrieved == content
        assert retrieved.decode("utf-8") == unicode_text

    async def test_key_overwrite_behavior(self, cache):
        # Test that setting at the same key overwrites whatever was there
        key = "overwrite_test"

        # Set initial content
        initial_content = "Initial content".encode("utf-8")
        await cache.set(key, initial_content)

        # Verify initial content is stored
        retrieved = await cache.get(key)
        assert retrieved == initial_content
        assert retrieved.decode("utf-8") == "Initial content"

        # Overwrite with different content
        new_content = "New content with 中文 and emojis 🚀".encode("utf-8")
        await cache.set(key, new_content)

        # Verify the content was overwritten
        retrieved = await cache.get(key)
        assert retrieved == new_content
        assert retrieved.decode("utf-8") == "New content with 中文 and emojis 🚀"
        assert retrieved != initial_content

        # Overwrite again with empty content
        empty_content = b""
        await cache.set(key, empty_content)

        # Verify empty content is stored
        retrieved = await cache.get(key)
        assert retrieved == empty_content
        assert retrieved.decode("utf-8") == ""

    @pytest.mark.parametrize(
        "invalid_key",
        [
            "",
            "x" * 121,
            "invalid/key",
            "invalid\\key",
            "invalid:key",
            "invalid*key",
            "invalid?key",
            "invalid<key",
            "invalid>key",
            "invalid|key",
        ],
    )
    async def test_invalid_keys(self, cache, invalid_key):
        with pytest.raises(ValueError):
            cache.get_path(invalid_key)

        with pytest.raises(ValueError):
            await cache.set(invalid_key, b"content")


class TestTemporaryFilesystemCache:
    def test_temporary_cache_creation(self):
        """Test that TemporaryFilesystemCache creates a temporary directory."""
        temp_cache = TemporaryFilesystemCache()

        # Should create a temporary directory
        assert temp_cache._cache_temp_dir is not None
        assert Path(temp_cache._cache_temp_dir).exists()

        # Check that the directory name (not full path) starts with the prefix
        temp_dir_name = Path(temp_cache._cache_temp_dir).name
        assert temp_dir_name.startswith("kiln_cache_")

        # Should have a FilesystemCache instance
        assert isinstance(temp_cache.filesystem_cache, FilesystemCache)
        assert temp_cache.filesystem_cache.cache_dir_path == Path(
            temp_cache._cache_temp_dir
        )

    async def test_multiple_instances_share_same_cache(self):
        """Test that multiple calls to shared() return the same cache instance."""
        # Get cache instances
        cache1 = TemporaryFilesystemCache.shared()
        cache2 = TemporaryFilesystemCache.shared()

        # Should be the same instance
        assert cache1 is cache2

        # Test that they share the same cache directory
        assert cache1.cache_dir_path == cache2.cache_dir_path

        # Test that content set in one is available in the other
        key = "shared_test"
        content = b"shared content"

        await cache1.set(key, content)
        retrieved = await cache2.get(key)
        assert retrieved == content

    def test_cache_directory_naming(self):
        """Test that the temporary cache directory has the correct naming pattern."""
        temp_cache = TemporaryFilesystemCache()
        temp_dir_name = Path(temp_cache._cache_temp_dir).name

        # Should start with the expected prefix
        assert temp_dir_name.startswith("kiln_cache_")

        # Should be a valid directory name
        assert len(temp_dir_name) > len("kiln_cache_")
        assert "/" not in temp_dir_name  # Should not contain path separators
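For reference, a minimal usage sketch of the new cache, inferred solely from the tests above rather than from module documentation; the key rules (plain names, max length 120) and return values are assumptions based on what these tests assert.

import asyncio
import tempfile
from pathlib import Path

from kiln_ai.utils.filesystem_cache import FilesystemCache, TemporaryFilesystemCache


async def demo() -> None:
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keys are plain names: no path separators, at most 120 characters (per the tests).
        cache = FilesystemCache(Path(tmp_dir))
        stored_path = await cache.set("report", b"cached bytes")  # returns the written path
        assert stored_path == cache.get_path("report")
        assert await cache.get("report") == b"cached bytes"
        assert await cache.get("missing") is None  # misses return None

    # Process-wide temporary cache; shared() appears to return the same cache on every call.
    shared = TemporaryFilesystemCache.shared()
    await shared.set("shared_key", b"shared bytes")
    assert await shared.get("shared_key") == b"shared bytes"


asyncio.run(demo())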
kiln_ai/utils/test_litellm.py (new file)
@@ -0,0 +1,206 @@
import pytest

from kiln_ai.adapters.ml_embedding_model_list import KilnEmbeddingModelProvider
from kiln_ai.adapters.ml_model_list import KilnModelProvider
from kiln_ai.datamodel.datamodel_enums import ModelProviderName
from kiln_ai.utils.litellm import LitellmProviderInfo, get_litellm_provider_info


class TestGetLitellmProviderInfo:
    """Test cases for get_litellm_provider_info function"""

    @pytest.fixture
    def sample_model_id(self):
        """Sample model ID for testing"""
        return "test-model-id"

    @pytest.fixture
    def embedding_provider(self, sample_model_id):
        """Sample KilnEmbeddingModelProvider for testing"""
        return KilnEmbeddingModelProvider(
            name=ModelProviderName.openai,
            model_id=sample_model_id,
            n_dimensions=1536,
        )

    @pytest.fixture
    def model_provider(self, sample_model_id):
        """Sample KilnModelProvider for testing"""
        return KilnModelProvider(
            name=ModelProviderName.openai,
            model_id=sample_model_id,
        )

    @pytest.mark.parametrize(
        "model_id",
        [
            None,
            "",
        ],
    )
    def test_missing_model_id_raises_error(self, model_id):
        """Test that missing model_id raises ValueError"""
        provider = KilnModelProvider(
            name=ModelProviderName.openai,
            model_id=model_id,
        )

        with pytest.raises(
            ValueError, match="Model ID is required for OpenAI compatible models"
        ):
            get_litellm_provider_info(provider)

    @pytest.mark.parametrize(
        "provider_name,expected_litellm_name,expected_is_custom",
        [
            (ModelProviderName.openrouter, "openrouter", False),
            (ModelProviderName.openai, "openai", False),
            (ModelProviderName.groq, "groq", False),
            (ModelProviderName.anthropic, "anthropic", False),
            (ModelProviderName.gemini_api, "gemini", False),
            (ModelProviderName.fireworks_ai, "fireworks_ai", False),
            (ModelProviderName.amazon_bedrock, "bedrock", False),
            (ModelProviderName.azure_openai, "azure", False),
            (ModelProviderName.huggingface, "huggingface", False),
            (ModelProviderName.vertex, "vertex_ai", False),
            (ModelProviderName.together_ai, "together_ai", False),
            (ModelProviderName.ollama, "openai", True),
            (ModelProviderName.openai_compatible, "openai", True),
            (ModelProviderName.kiln_custom_registry, "openai", True),
            (ModelProviderName.kiln_fine_tune, "openai", True),
        ],
    )
    def test_provider_mappings_with_model_provider(
        self, provider_name, expected_litellm_name, expected_is_custom, sample_model_id
    ):
        """Test provider name mappings for KilnModelProvider"""
        provider = KilnModelProvider(
            name=provider_name,
            model_id=sample_model_id,
        )

        result = get_litellm_provider_info(provider)

        assert isinstance(result, LitellmProviderInfo)
        assert result.provider_name == expected_litellm_name
        assert result.is_custom == expected_is_custom
        assert result.litellm_model_id == f"{expected_litellm_name}/{sample_model_id}"

    @pytest.mark.parametrize(
        "provider_name,expected_litellm_name,expected_is_custom",
        [
            (ModelProviderName.openrouter, "openrouter", False),
            (ModelProviderName.openai, "openai", False),
            (ModelProviderName.groq, "groq", False),
            (ModelProviderName.anthropic, "anthropic", False),
            (ModelProviderName.gemini_api, "gemini", False),
            (ModelProviderName.fireworks_ai, "fireworks_ai", False),
            (ModelProviderName.amazon_bedrock, "bedrock", False),
            (ModelProviderName.azure_openai, "azure", False),
            (ModelProviderName.huggingface, "huggingface", False),
            (ModelProviderName.vertex, "vertex_ai", False),
            (ModelProviderName.together_ai, "together_ai", False),
            (ModelProviderName.ollama, "openai", True),
            (ModelProviderName.openai_compatible, "openai", True),
            (ModelProviderName.kiln_custom_registry, "openai", True),
            (ModelProviderName.kiln_fine_tune, "openai", True),
        ],
    )
    def test_provider_mappings_with_embedding_provider(
        self, provider_name, expected_litellm_name, expected_is_custom, sample_model_id
    ):
        """Test provider name mappings for KilnEmbeddingModelProvider"""
        provider = KilnEmbeddingModelProvider(
            name=provider_name,
            model_id=sample_model_id,
            n_dimensions=1536,
        )

        result = get_litellm_provider_info(provider)

        assert isinstance(result, LitellmProviderInfo)
        assert result.provider_name == expected_litellm_name
        assert result.is_custom == expected_is_custom
        assert result.litellm_model_id == f"{expected_litellm_name}/{sample_model_id}"

    def test_custom_providers_use_openai_format(self, sample_model_id):
        """Test that custom providers use 'openai' as the litellm provider name"""
        custom_providers = [
            ModelProviderName.ollama,
            ModelProviderName.openai_compatible,
            ModelProviderName.kiln_custom_registry,
            ModelProviderName.kiln_fine_tune,
        ]

        for provider_name in custom_providers:
            provider = KilnModelProvider(
                name=provider_name,
                model_id=sample_model_id,
            )

            result = get_litellm_provider_info(provider)

            assert result.provider_name == "openai"
            assert result.is_custom is True
            assert result.litellm_model_id == f"openai/{sample_model_id}"

    def test_non_custom_providers_use_correct_format(self, sample_model_id):
        """Test that non-custom providers use their actual provider names"""
        non_custom_providers = [
            (ModelProviderName.openai, "openai"),
            (ModelProviderName.anthropic, "anthropic"),
            (ModelProviderName.groq, "groq"),
            (ModelProviderName.gemini_api, "gemini"),
        ]

        for provider_name, expected_name in non_custom_providers:
            provider = KilnModelProvider(
                name=provider_name,
                model_id=sample_model_id,
            )

            result = get_litellm_provider_info(provider)

            assert result.provider_name == expected_name
            assert result.is_custom is False
            assert result.litellm_model_id == f"{expected_name}/{sample_model_id}"

    def test_litellm_model_id_format(self, embedding_provider):
        """Test that litellm_model_id follows the correct format"""
        result = get_litellm_provider_info(embedding_provider)

        expected_format = f"{result.provider_name}/{embedding_provider.model_id}"
        assert result.litellm_model_id == expected_format

    def test_return_type_structure(self, model_provider):
        """Test that the return type has all expected fields"""
        result = get_litellm_provider_info(model_provider)

        assert hasattr(result, "provider_name")
        assert hasattr(result, "is_custom")
        assert hasattr(result, "litellm_model_id")

        assert isinstance(result.provider_name, str)
        assert isinstance(result.is_custom, bool)
        assert isinstance(result.litellm_model_id, str)

    def test_works_with_both_provider_types(self, sample_model_id):
        """Test that function works with both KilnModelProvider and KilnEmbeddingModelProvider"""
        model_provider = KilnModelProvider(
            name=ModelProviderName.openai,
            model_id=sample_model_id,
        )

        embedding_provider = KilnEmbeddingModelProvider(
            name=ModelProviderName.openai,
            model_id=sample_model_id,
            n_dimensions=1536,
        )

        model_result = get_litellm_provider_info(model_provider)
        embedding_result = get_litellm_provider_info(embedding_provider)

        # Results should be identical for same provider name and model ID
        assert model_result.provider_name == embedding_result.provider_name
        assert model_result.is_custom == embedding_result.is_custom
        assert model_result.litellm_model_id == embedding_result.litellm_model_id
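For reference, a small sketch of the provider mapping exercised above, inferred solely from these tests: the field names and the "provider/model-id" format come from the assertions, and the model ID used here is a hypothetical placeholder.

from kiln_ai.adapters.ml_model_list import KilnModelProvider
from kiln_ai.datamodel.datamodel_enums import ModelProviderName
from kiln_ai.utils.litellm import get_litellm_provider_info

# Map a Kiln provider entry to LiteLLM's naming scheme.
provider = KilnModelProvider(
    name=ModelProviderName.ollama,
    model_id="llama3.1:8b",  # hypothetical model ID, for illustration only
)
info = get_litellm_provider_info(provider)

# Per the parametrized tests, custom-style providers (Ollama, OpenAI-compatible,
# Kiln custom registry, Kiln fine-tunes) are routed through the "openai" prefix,
# while hosted providers keep their own LiteLLM prefix.
assert info.provider_name == "openai"
assert info.is_custom is True
assert info.litellm_model_id == "openai/llama3.1:8b"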