kiln-ai 0.20.1__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +6 -0
- kiln_ai/adapters/adapter_registry.py +43 -226
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/eval_runner.py +6 -2
- kiln_ai/adapters/eval/test_base_eval.py +1 -3
- kiln_ai/adapters/eval/test_g_eval.py +1 -1
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +192 -0
- kiln_ai/adapters/ml_model_list.py +382 -4
- kiln_ai/adapters/model_adapters/litellm_adapter.py +7 -69
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +1 -1
- kiln_ai/adapters/model_adapters/test_structured_output.py +3 -1
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/provider_tools.py +190 -46
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/test_adapter_registry.py +579 -86
- kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
- kiln_ai/adapters/test_ml_model_list.py +212 -0
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +199 -8
- kiln_ai/adapters/test_remote_config.py +551 -56
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +16 -13
- kiln_ai/datamodel/basemodel.py +170 -1
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +27 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/extraction.py +303 -0
- kiln_ai/datamodel/project.py +33 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +1 -1
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +470 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_tool_id.py +81 -0
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +22 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/mcp_session_manager.py +4 -1
- kiln_ai/tools/rag_tools.py +157 -0
- kiln_ai/tools/test_mcp_session_manager.py +1 -1
- kiln_ai/tools/test_rag_tools.py +848 -0
- kiln_ai/tools/test_tool_registry.py +91 -2
- kiln_ai/tools/tool_registry.py +21 -0
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +2 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/pdf_utils.py +38 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_pdf_utils.py +73 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +7 -1
- kiln_ai-0.21.0.dist-info/RECORD +211 -0
- kiln_ai-0.20.1.dist-info/RECORD +0 -138
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -4,10 +4,12 @@ from pydantic import BaseModel, ValidationError
|
|
|
4
4
|
from kiln_ai.datamodel.tool_id import (
|
|
5
5
|
MCP_LOCAL_TOOL_ID_PREFIX,
|
|
6
6
|
MCP_REMOTE_TOOL_ID_PREFIX,
|
|
7
|
+
RAG_TOOL_ID_PREFIX,
|
|
7
8
|
KilnBuiltInToolId,
|
|
8
9
|
ToolId,
|
|
9
10
|
_check_tool_id,
|
|
10
11
|
mcp_server_and_tool_name_from_id,
|
|
12
|
+
rag_config_id_from_id,
|
|
11
13
|
)
|
|
12
14
|
|
|
13
15
|
|
|
@@ -113,6 +115,36 @@ class TestCheckToolId:
|
|
|
113
115
|
with pytest.raises(ValueError, match="Invalid tool ID"):
|
|
114
116
|
_check_tool_id("mcp::wrong::server::tool")
|
|
115
117
|
|
|
118
|
+
def test_valid_rag_tools(self):
|
|
119
|
+
"""Test validation of valid RAG tools."""
|
|
120
|
+
valid_ids = [
|
|
121
|
+
"kiln_tool::rag::config1",
|
|
122
|
+
"kiln_tool::rag::my_rag_config",
|
|
123
|
+
"kiln_tool::rag::test_config_123",
|
|
124
|
+
]
|
|
125
|
+
for tool_id in valid_ids:
|
|
126
|
+
result = _check_tool_id(tool_id)
|
|
127
|
+
assert result == tool_id
|
|
128
|
+
|
|
129
|
+
def test_invalid_rag_format(self):
|
|
130
|
+
"""Test validation fails for invalid RAG tool formats."""
|
|
131
|
+
# These IDs start with the RAG prefix but have invalid formats
|
|
132
|
+
rag_invalid_ids = [
|
|
133
|
+
"kiln_tool::rag::", # Missing config ID
|
|
134
|
+
"kiln_tool::rag::config::extra", # Too many parts
|
|
135
|
+
]
|
|
136
|
+
|
|
137
|
+
for invalid_id in rag_invalid_ids:
|
|
138
|
+
with pytest.raises(ValueError, match="Invalid RAG tool ID"):
|
|
139
|
+
_check_tool_id(invalid_id)
|
|
140
|
+
|
|
141
|
+
def test_rag_tool_empty_config_id(self):
|
|
142
|
+
"""Test that RAG tool with empty config ID is handled properly."""
|
|
143
|
+
# This tests the case where rag_config_id_from_id returns empty string
|
|
144
|
+
# which should trigger line 66 in the source
|
|
145
|
+
with pytest.raises(ValueError, match="Invalid RAG tool ID"):
|
|
146
|
+
_check_tool_id("kiln_tool::rag::")
|
|
147
|
+
|
|
116
148
|
|
|
117
149
|
class TestMcpServerAndToolNameFromId:
|
|
118
150
|
"""Test the mcp_server_and_tool_name_from_id function."""
|
|
@@ -197,6 +229,9 @@ class TestToolIdPydanticType:
|
|
|
197
229
|
# Local MCP tools
|
|
198
230
|
"mcp::local::server1::tool1",
|
|
199
231
|
"mcp::local::my_server::my_tool",
|
|
232
|
+
# RAG tools
|
|
233
|
+
"kiln_tool::rag::config1",
|
|
234
|
+
"kiln_tool::rag::my_rag_config",
|
|
200
235
|
]
|
|
201
236
|
|
|
202
237
|
for tool_id in valid_ids:
|
|
@@ -212,6 +247,8 @@ class TestToolIdPydanticType:
|
|
|
212
247
|
"mcp::remote::server",
|
|
213
248
|
"mcp::local::",
|
|
214
249
|
"mcp::local::server",
|
|
250
|
+
"kiln_tool::rag::",
|
|
251
|
+
"kiln_tool::rag::config::extra",
|
|
215
252
|
]
|
|
216
253
|
|
|
217
254
|
for invalid_id in invalid_ids:
|
|
@@ -237,3 +274,47 @@ class TestConstants:
|
|
|
237
274
|
def test_mcp_local_tool_id_prefix(self):
|
|
238
275
|
"""Test the MCP local tool ID prefix constant."""
|
|
239
276
|
assert MCP_LOCAL_TOOL_ID_PREFIX == "mcp::local::"
|
|
277
|
+
|
|
278
|
+
def test_rag_tool_id_prefix(self):
|
|
279
|
+
"""Test the RAG tool ID prefix constant."""
|
|
280
|
+
assert RAG_TOOL_ID_PREFIX == "kiln_tool::rag::"
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
class TestRagConfigIdFromId:
|
|
284
|
+
"""Test the rag_config_id_from_id function."""
|
|
285
|
+
|
|
286
|
+
def test_valid_rag_ids(self):
|
|
287
|
+
"""Test parsing valid RAG tool IDs."""
|
|
288
|
+
test_cases = [
|
|
289
|
+
("kiln_tool::rag::config1", "config1"),
|
|
290
|
+
("kiln_tool::rag::my_rag_config", "my_rag_config"),
|
|
291
|
+
("kiln_tool::rag::test_config_123", "test_config_123"),
|
|
292
|
+
("kiln_tool::rag::a", "a"), # Minimal valid case
|
|
293
|
+
]
|
|
294
|
+
|
|
295
|
+
for tool_id, expected in test_cases:
|
|
296
|
+
result = rag_config_id_from_id(tool_id)
|
|
297
|
+
assert result == expected
|
|
298
|
+
|
|
299
|
+
def test_invalid_rag_ids(self):
|
|
300
|
+
"""Test parsing fails for invalid RAG tool IDs."""
|
|
301
|
+
# Test various invalid formats that should trigger line 104
|
|
302
|
+
invalid_ids = [
|
|
303
|
+
"kiln_tool::rag::config::extra", # Too many parts (4 parts)
|
|
304
|
+
"wrong::rag::config", # Wrong prefix
|
|
305
|
+
"kiln_tool::wrong::config", # Wrong middle part
|
|
306
|
+
"rag::config", # Too few parts (2 parts)
|
|
307
|
+
"", # Empty string
|
|
308
|
+
"single_part", # Only 1 part
|
|
309
|
+
]
|
|
310
|
+
|
|
311
|
+
for invalid_id in invalid_ids:
|
|
312
|
+
with pytest.raises(ValueError, match="Invalid RAG tool ID"):
|
|
313
|
+
rag_config_id_from_id(invalid_id)
|
|
314
|
+
|
|
315
|
+
def test_rag_id_with_empty_config_id(self):
|
|
316
|
+
"""Test that RAG tool ID with empty config ID returns empty string."""
|
|
317
|
+
# This is actually valid according to the parser - it returns empty string
|
|
318
|
+
# The validation for empty config ID happens in _check_tool_id
|
|
319
|
+
result = rag_config_id_from_id("kiln_tool::rag::")
|
|
320
|
+
assert result == ""
|
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
from pydantic import ValidationError
|
|
3
|
+
|
|
4
|
+
from kiln_ai.datamodel.project import Project
|
|
5
|
+
from kiln_ai.datamodel.vector_store import (
|
|
6
|
+
LanceDBConfigBaseProperties,
|
|
7
|
+
VectorStoreConfig,
|
|
8
|
+
VectorStoreType,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@pytest.fixture
|
|
13
|
+
def mock_project(tmp_path):
|
|
14
|
+
project_path = tmp_path / "test_project" / "project.kiln"
|
|
15
|
+
project_path.parent.mkdir()
|
|
16
|
+
|
|
17
|
+
project = Project(name="Test Project", path=project_path)
|
|
18
|
+
project.save_to_file()
|
|
19
|
+
|
|
20
|
+
return project
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@pytest.fixture
|
|
24
|
+
def mock_vector_store_fts_config_properties():
|
|
25
|
+
return {
|
|
26
|
+
"similarity_top_k": 10,
|
|
27
|
+
"overfetch_factor": 2,
|
|
28
|
+
"vector_column_name": "vector",
|
|
29
|
+
"text_key": "text",
|
|
30
|
+
"doc_id_key": "doc_id",
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@pytest.fixture
|
|
35
|
+
def mock_vector_store_vector_config_properties():
|
|
36
|
+
return {
|
|
37
|
+
"similarity_top_k": 10,
|
|
38
|
+
"overfetch_factor": 2,
|
|
39
|
+
"vector_column_name": "vector",
|
|
40
|
+
"text_key": "text",
|
|
41
|
+
"doc_id_key": "doc_id",
|
|
42
|
+
"nprobes": 1,
|
|
43
|
+
}
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class TestVectorStoreType:
|
|
47
|
+
def test_vector_store_type_values(self):
|
|
48
|
+
"""Test that VectorStoreType enum has expected values."""
|
|
49
|
+
assert VectorStoreType.LANCE_DB_FTS == "lancedb_fts"
|
|
50
|
+
assert VectorStoreType.LANCE_DB_HYBRID == "lancedb_hybrid"
|
|
51
|
+
assert VectorStoreType.LANCE_DB_VECTOR == "lancedb_vector"
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class TestLanceDBConfigBaseProperties:
|
|
55
|
+
def test_valid_lance_db_config_base_properties(self):
|
|
56
|
+
"""Test creating valid LanceDBConfigBaseProperties."""
|
|
57
|
+
config = LanceDBConfigBaseProperties(
|
|
58
|
+
similarity_top_k=10,
|
|
59
|
+
overfetch_factor=2,
|
|
60
|
+
vector_column_name="vector",
|
|
61
|
+
text_key="text",
|
|
62
|
+
doc_id_key="doc_id",
|
|
63
|
+
nprobes=1,
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
assert config.similarity_top_k == 10
|
|
67
|
+
assert config.overfetch_factor == 2
|
|
68
|
+
assert config.vector_column_name == "vector"
|
|
69
|
+
assert config.text_key == "text"
|
|
70
|
+
assert config.doc_id_key == "doc_id"
|
|
71
|
+
assert config.nprobes == 1
|
|
72
|
+
|
|
73
|
+
def test_lance_db_config_base_properties_without_nprobes(self):
|
|
74
|
+
"""Test creating LanceDBConfigBaseProperties without nprobes."""
|
|
75
|
+
config = LanceDBConfigBaseProperties(
|
|
76
|
+
similarity_top_k=10,
|
|
77
|
+
overfetch_factor=2,
|
|
78
|
+
vector_column_name="vector",
|
|
79
|
+
text_key="text",
|
|
80
|
+
doc_id_key="doc_id",
|
|
81
|
+
)
|
|
82
|
+
|
|
83
|
+
assert config.similarity_top_k == 10
|
|
84
|
+
assert config.nprobes is None
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class TestVectorStoreConfig:
|
|
88
|
+
def test_invalid_store_type(self, mock_vector_store_fts_config_properties):
|
|
89
|
+
"""Test creating VectorStoreConfig with invalid store type."""
|
|
90
|
+
with pytest.raises(ValidationError, match="Input should be"):
|
|
91
|
+
VectorStoreConfig(
|
|
92
|
+
name="test_store",
|
|
93
|
+
store_type="invalid_type", # type: ignore
|
|
94
|
+
properties=mock_vector_store_fts_config_properties,
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
def test_invalid_store_type_after_creation(
|
|
98
|
+
self, mock_vector_store_fts_config_properties
|
|
99
|
+
):
|
|
100
|
+
"""Test creating VectorStoreConfig with invalid store type after creation."""
|
|
101
|
+
config = VectorStoreConfig(
|
|
102
|
+
name="test_store",
|
|
103
|
+
store_type=VectorStoreType.LANCE_DB_FTS,
|
|
104
|
+
properties=mock_vector_store_fts_config_properties,
|
|
105
|
+
)
|
|
106
|
+
with pytest.raises(ValidationError, match="Input should be"):
|
|
107
|
+
config.store_type = "invalid_type" # type: ignore
|
|
108
|
+
|
|
109
|
+
def test_valid_lance_db_fts_vector_store_config(
|
|
110
|
+
self, mock_vector_store_fts_config_properties
|
|
111
|
+
):
|
|
112
|
+
"""Test creating valid VectorStoreConfig with LanceDB FTS."""
|
|
113
|
+
config = VectorStoreConfig(
|
|
114
|
+
name="test_store",
|
|
115
|
+
store_type=VectorStoreType.LANCE_DB_FTS,
|
|
116
|
+
properties=mock_vector_store_fts_config_properties,
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
assert config.name == "test_store"
|
|
120
|
+
assert config.store_type == VectorStoreType.LANCE_DB_FTS
|
|
121
|
+
assert config.properties["similarity_top_k"] == 10
|
|
122
|
+
assert config.properties["overfetch_factor"] == 2
|
|
123
|
+
assert config.properties["vector_column_name"] == "vector"
|
|
124
|
+
assert config.properties["text_key"] == "text"
|
|
125
|
+
assert config.properties["doc_id_key"] == "doc_id"
|
|
126
|
+
|
|
127
|
+
def test_valid_lance_db_vector_store_config(
|
|
128
|
+
self, mock_vector_store_vector_config_properties
|
|
129
|
+
):
|
|
130
|
+
"""Test creating valid VectorStoreConfig with LanceDB Vector."""
|
|
131
|
+
config = VectorStoreConfig(
|
|
132
|
+
name="test_store",
|
|
133
|
+
store_type=VectorStoreType.LANCE_DB_VECTOR,
|
|
134
|
+
properties=mock_vector_store_vector_config_properties,
|
|
135
|
+
)
|
|
136
|
+
|
|
137
|
+
assert config.name == "test_store"
|
|
138
|
+
assert config.store_type == VectorStoreType.LANCE_DB_VECTOR
|
|
139
|
+
assert config.properties["similarity_top_k"] == 10
|
|
140
|
+
assert config.properties["nprobes"] == 1
|
|
141
|
+
|
|
142
|
+
def test_valid_lance_db_hybrid_store_config(
|
|
143
|
+
self, mock_vector_store_vector_config_properties
|
|
144
|
+
):
|
|
145
|
+
"""Test creating valid VectorStoreConfig with LanceDB Hybrid."""
|
|
146
|
+
config = VectorStoreConfig(
|
|
147
|
+
name="test_store",
|
|
148
|
+
store_type=VectorStoreType.LANCE_DB_HYBRID,
|
|
149
|
+
properties=mock_vector_store_vector_config_properties,
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
assert config.name == "test_store"
|
|
153
|
+
assert config.store_type == VectorStoreType.LANCE_DB_HYBRID
|
|
154
|
+
assert config.properties["nprobes"] == 1
|
|
155
|
+
|
|
156
|
+
def test_vector_store_config_missing_required_property(
|
|
157
|
+
self, mock_vector_store_fts_config_properties
|
|
158
|
+
):
|
|
159
|
+
"""Test VectorStoreConfig validation fails when required property is missing."""
|
|
160
|
+
mock_vector_store_fts_config_properties.pop("similarity_top_k")
|
|
161
|
+
with pytest.raises(
|
|
162
|
+
ValidationError,
|
|
163
|
+
match=r".*similarity_top_k is a required property",
|
|
164
|
+
):
|
|
165
|
+
VectorStoreConfig(
|
|
166
|
+
name="test_store",
|
|
167
|
+
store_type=VectorStoreType.LANCE_DB_FTS,
|
|
168
|
+
properties=mock_vector_store_fts_config_properties,
|
|
169
|
+
)
|
|
170
|
+
|
|
171
|
+
def test_vector_store_config_invalid_property_type(
|
|
172
|
+
self, mock_vector_store_fts_config_properties
|
|
173
|
+
):
|
|
174
|
+
"""Test VectorStoreConfig validation fails when property has wrong type."""
|
|
175
|
+
mock_vector_store_fts_config_properties["similarity_top_k"] = "not_an_int"
|
|
176
|
+
with pytest.raises(
|
|
177
|
+
ValidationError,
|
|
178
|
+
match=r".*similarity_top_k must be of type",
|
|
179
|
+
):
|
|
180
|
+
VectorStoreConfig(
|
|
181
|
+
name="test_store",
|
|
182
|
+
store_type=VectorStoreType.LANCE_DB_FTS,
|
|
183
|
+
properties=mock_vector_store_fts_config_properties,
|
|
184
|
+
)
|
|
185
|
+
|
|
186
|
+
def test_vector_store_config_fts_missing_nprobes_is_valid(
|
|
187
|
+
self, mock_vector_store_fts_config_properties
|
|
188
|
+
):
|
|
189
|
+
"""Test VectorStoreConfig with FTS type doesn't require nprobes."""
|
|
190
|
+
config = VectorStoreConfig(
|
|
191
|
+
name="test_store",
|
|
192
|
+
store_type=VectorStoreType.LANCE_DB_FTS,
|
|
193
|
+
properties=mock_vector_store_fts_config_properties,
|
|
194
|
+
)
|
|
195
|
+
assert config.store_type == VectorStoreType.LANCE_DB_FTS
|
|
196
|
+
|
|
197
|
+
def test_vector_store_config_vector_missing_nprobes_fails(
|
|
198
|
+
self, mock_vector_store_vector_config_properties
|
|
199
|
+
):
|
|
200
|
+
"""Test VectorStoreConfig with VECTOR type requires nprobes."""
|
|
201
|
+
mock_vector_store_vector_config_properties.pop("nprobes")
|
|
202
|
+
with pytest.raises(
|
|
203
|
+
ValidationError,
|
|
204
|
+
match=r".*nprobes is a required property",
|
|
205
|
+
):
|
|
206
|
+
VectorStoreConfig(
|
|
207
|
+
name="test_store",
|
|
208
|
+
store_type=VectorStoreType.LANCE_DB_VECTOR,
|
|
209
|
+
properties=mock_vector_store_vector_config_properties,
|
|
210
|
+
)
|
|
211
|
+
|
|
212
|
+
def test_lancedb_properties(self, mock_vector_store_vector_config_properties):
|
|
213
|
+
"""Test lancedb_properties method returns correct LanceDBConfigBaseProperties."""
|
|
214
|
+
config = VectorStoreConfig(
|
|
215
|
+
name="test_store",
|
|
216
|
+
store_type=VectorStoreType.LANCE_DB_VECTOR,
|
|
217
|
+
properties=mock_vector_store_vector_config_properties,
|
|
218
|
+
)
|
|
219
|
+
|
|
220
|
+
props = config.lancedb_properties
|
|
221
|
+
|
|
222
|
+
assert isinstance(props, LanceDBConfigBaseProperties)
|
|
223
|
+
assert props.similarity_top_k == 10
|
|
224
|
+
assert props.overfetch_factor == 2
|
|
225
|
+
assert props.vector_column_name == "vector"
|
|
226
|
+
assert props.text_key == "text"
|
|
227
|
+
assert props.doc_id_key == "doc_id"
|
|
228
|
+
assert props.nprobes == 1
|
|
229
|
+
|
|
230
|
+
def test_vector_store_config_inherits_from_kiln_parented_model(
|
|
231
|
+
self, mock_vector_store_fts_config_properties
|
|
232
|
+
):
|
|
233
|
+
"""Test that VectorStoreConfig inherits from KilnParentedModel."""
|
|
234
|
+
config = VectorStoreConfig(
|
|
235
|
+
name="test_store",
|
|
236
|
+
store_type=VectorStoreType.LANCE_DB_FTS,
|
|
237
|
+
properties=mock_vector_store_fts_config_properties,
|
|
238
|
+
)
|
|
239
|
+
|
|
240
|
+
# Check that it has the expected base fields
|
|
241
|
+
assert hasattr(config, "id")
|
|
242
|
+
assert hasattr(config, "v")
|
|
243
|
+
assert hasattr(config, "created_at")
|
|
244
|
+
assert hasattr(config, "created_by")
|
|
245
|
+
assert hasattr(config, "parent")
|
|
246
|
+
|
|
247
|
+
@pytest.mark.parametrize(
|
|
248
|
+
"name",
|
|
249
|
+
["valid_name", "valid name", "valid-name", "valid_name_123", "VALID_NAME"],
|
|
250
|
+
)
|
|
251
|
+
def test_vector_store_config_valid_names(
|
|
252
|
+
self, name, mock_vector_store_fts_config_properties
|
|
253
|
+
):
|
|
254
|
+
"""Test VectorStoreConfig accepts valid names."""
|
|
255
|
+
config = VectorStoreConfig(
|
|
256
|
+
name=name,
|
|
257
|
+
store_type=VectorStoreType.LANCE_DB_FTS,
|
|
258
|
+
properties=mock_vector_store_fts_config_properties,
|
|
259
|
+
)
|
|
260
|
+
assert config.name == name
|
|
261
|
+
|
|
262
|
+
@pytest.mark.parametrize(
|
|
263
|
+
"name",
|
|
264
|
+
[
|
|
265
|
+
"",
|
|
266
|
+
"a" * 121, # Too long
|
|
267
|
+
],
|
|
268
|
+
)
|
|
269
|
+
def test_vector_store_config_invalid_names(
|
|
270
|
+
self, name, mock_vector_store_fts_config_properties
|
|
271
|
+
):
|
|
272
|
+
"""Test VectorStoreConfig rejects invalid names."""
|
|
273
|
+
with pytest.raises(ValidationError):
|
|
274
|
+
VectorStoreConfig(
|
|
275
|
+
name=name,
|
|
276
|
+
store_type=VectorStoreType.LANCE_DB_FTS,
|
|
277
|
+
properties=mock_vector_store_fts_config_properties,
|
|
278
|
+
)
|
|
279
|
+
|
|
280
|
+
def test_parent_project(
|
|
281
|
+
self, mock_project, mock_vector_store_fts_config_properties
|
|
282
|
+
):
|
|
283
|
+
"""Test that parent project is returned correctly."""
|
|
284
|
+
config = VectorStoreConfig(
|
|
285
|
+
name="test_store",
|
|
286
|
+
store_type=VectorStoreType.LANCE_DB_FTS,
|
|
287
|
+
properties=mock_vector_store_fts_config_properties,
|
|
288
|
+
parent=mock_project,
|
|
289
|
+
)
|
|
290
|
+
|
|
291
|
+
assert config.parent_project() is mock_project
|
|
292
|
+
|
|
293
|
+
def test_vector_store_config_parent_project_none(
|
|
294
|
+
self, mock_vector_store_fts_config_properties
|
|
295
|
+
):
|
|
296
|
+
"""Test that parent project is None if not set."""
|
|
297
|
+
config = VectorStoreConfig(
|
|
298
|
+
name="test_store",
|
|
299
|
+
store_type=VectorStoreType.LANCE_DB_FTS,
|
|
300
|
+
properties=mock_vector_store_fts_config_properties,
|
|
301
|
+
)
|
|
302
|
+
|
|
303
|
+
assert config.parent_project() is None
|
|
304
|
+
|
|
305
|
+
def test_project_has_vector_store_configs(
|
|
306
|
+
self, mock_project, mock_vector_store_fts_config_properties
|
|
307
|
+
):
|
|
308
|
+
"""Test that project has vector store configs."""
|
|
309
|
+
config = VectorStoreConfig(
|
|
310
|
+
name="test_store",
|
|
311
|
+
store_type=VectorStoreType.LANCE_DB_FTS,
|
|
312
|
+
properties=mock_vector_store_fts_config_properties,
|
|
313
|
+
parent=mock_project,
|
|
314
|
+
)
|
|
315
|
+
config.save_to_file()
|
|
316
|
+
|
|
317
|
+
assert len(mock_project.vector_store_configs(readonly=True)) == 1
|
|
318
|
+
assert config.id in [
|
|
319
|
+
vc.id for vc in mock_project.vector_store_configs(readonly=True)
|
|
320
|
+
]
|
kiln_ai/datamodel/tool_id.py
CHANGED
|
@@ -26,6 +26,7 @@ class KilnBuiltInToolId(str, Enum):
|
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
MCP_REMOTE_TOOL_ID_PREFIX = "mcp::remote::"
|
|
29
|
+
RAG_TOOL_ID_PREFIX = "kiln_tool::rag::"
|
|
29
30
|
MCP_LOCAL_TOOL_ID_PREFIX = "mcp::local::"
|
|
30
31
|
|
|
31
32
|
|
|
@@ -58,6 +59,15 @@ def _check_tool_id(id: str) -> str:
|
|
|
58
59
|
)
|
|
59
60
|
return id
|
|
60
61
|
|
|
62
|
+
# RAG tools must have format: kiln_tool::rag::<rag_config_id>
|
|
63
|
+
if id.startswith(RAG_TOOL_ID_PREFIX):
|
|
64
|
+
rag_config_id = rag_config_id_from_id(id)
|
|
65
|
+
if not rag_config_id:
|
|
66
|
+
raise ValueError(
|
|
67
|
+
f"Invalid RAG tool ID: {id}. Expected format: 'kiln_tool::rag::<rag_config_id>'."
|
|
68
|
+
)
|
|
69
|
+
return id
|
|
70
|
+
|
|
61
71
|
raise ValueError(f"Invalid tool ID: {id}")
|
|
62
72
|
|
|
63
73
|
|
|
@@ -81,3 +91,15 @@ def mcp_server_and_tool_name_from_id(id: str) -> tuple[str, str]:
|
|
|
81
91
|
f"Invalid MCP tool ID: {id}. Expected format: 'mcp::(remote|local)::<server_id>::<tool_name>'."
|
|
82
92
|
)
|
|
83
93
|
return parts[2], parts[3] # server_id, tool_name
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def rag_config_id_from_id(id: str) -> str:
|
|
97
|
+
"""
|
|
98
|
+
Get the RAG config ID from the ID.
|
|
99
|
+
"""
|
|
100
|
+
parts = id.split("::")
|
|
101
|
+
if not id.startswith(RAG_TOOL_ID_PREFIX) or len(parts) != 3:
|
|
102
|
+
raise ValueError(
|
|
103
|
+
f"Invalid RAG tool ID: {id}. Expected format: 'kiln_tool::rag::<rag_config_id>'."
|
|
104
|
+
)
|
|
105
|
+
return parts[2]
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from typing import TYPE_CHECKING, Union
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel, Field, model_validator
|
|
5
|
+
|
|
6
|
+
from kiln_ai.datamodel.basemodel import FilenameString, KilnParentedModel
|
|
7
|
+
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
|
|
8
|
+
from kiln_ai.utils.validation import (
|
|
9
|
+
validate_return_dict_prop,
|
|
10
|
+
validate_return_dict_prop_optional,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from kiln_ai.datamodel.project import Project
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class VectorStoreType(str, Enum):
|
|
18
|
+
LANCE_DB_FTS = "lancedb_fts"
|
|
19
|
+
LANCE_DB_HYBRID = "lancedb_hybrid"
|
|
20
|
+
LANCE_DB_VECTOR = "lancedb_vector"
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class LanceDBConfigBaseProperties(BaseModel):
|
|
24
|
+
similarity_top_k: int = Field(
|
|
25
|
+
description="The number of results to return from the vector store.",
|
|
26
|
+
)
|
|
27
|
+
overfetch_factor: int = Field(
|
|
28
|
+
description="The overfetch factor to use for the vector search.",
|
|
29
|
+
)
|
|
30
|
+
vector_column_name: str = Field(
|
|
31
|
+
description="The name of the vector column in the vector store.",
|
|
32
|
+
)
|
|
33
|
+
text_key: str = Field(
|
|
34
|
+
description="The name of the text column in the vector store.",
|
|
35
|
+
)
|
|
36
|
+
doc_id_key: str = Field(
|
|
37
|
+
description="The name of the document id column in the vector store.",
|
|
38
|
+
)
|
|
39
|
+
nprobes: int | None = Field(
|
|
40
|
+
description="The number of probes to use for the vector search.",
|
|
41
|
+
default=None,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class VectorStoreConfig(KilnParentedModel):
|
|
46
|
+
name: FilenameString = Field(
|
|
47
|
+
description="A name for your own reference to identify the vector store config.",
|
|
48
|
+
)
|
|
49
|
+
description: str | None = Field(
|
|
50
|
+
description="A description for your own reference.",
|
|
51
|
+
default=None,
|
|
52
|
+
)
|
|
53
|
+
store_type: VectorStoreType = Field(
|
|
54
|
+
description="The type of vector store to use.",
|
|
55
|
+
)
|
|
56
|
+
properties: dict[str, str | int | float | None] = Field(
|
|
57
|
+
description="The properties of the vector store config, specific to the selected store_type.",
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
@model_validator(mode="after")
|
|
61
|
+
def validate_properties(self):
|
|
62
|
+
match self.store_type:
|
|
63
|
+
case (
|
|
64
|
+
VectorStoreType.LANCE_DB_FTS
|
|
65
|
+
| VectorStoreType.LANCE_DB_HYBRID
|
|
66
|
+
| VectorStoreType.LANCE_DB_VECTOR
|
|
67
|
+
):
|
|
68
|
+
return self.validate_lancedb_properties(self.store_type)
|
|
69
|
+
case _:
|
|
70
|
+
raise_exhaustive_enum_error(self.store_type)
|
|
71
|
+
|
|
72
|
+
def validate_lancedb_properties(self, store_type: VectorStoreType):
|
|
73
|
+
err_msg_prefix = f"LanceDB vector store configs properties for {store_type}:"
|
|
74
|
+
validate_return_dict_prop(
|
|
75
|
+
self.properties, "similarity_top_k", int, err_msg_prefix
|
|
76
|
+
)
|
|
77
|
+
validate_return_dict_prop(
|
|
78
|
+
self.properties, "overfetch_factor", int, err_msg_prefix
|
|
79
|
+
)
|
|
80
|
+
validate_return_dict_prop(
|
|
81
|
+
self.properties, "vector_column_name", str, err_msg_prefix
|
|
82
|
+
)
|
|
83
|
+
validate_return_dict_prop(self.properties, "text_key", str, err_msg_prefix)
|
|
84
|
+
validate_return_dict_prop(self.properties, "doc_id_key", str, err_msg_prefix)
|
|
85
|
+
|
|
86
|
+
# nprobes is only used for vector and hybrid queries
|
|
87
|
+
if (
|
|
88
|
+
store_type == VectorStoreType.LANCE_DB_VECTOR
|
|
89
|
+
or store_type == VectorStoreType.LANCE_DB_HYBRID
|
|
90
|
+
):
|
|
91
|
+
validate_return_dict_prop(self.properties, "nprobes", int, err_msg_prefix)
|
|
92
|
+
|
|
93
|
+
return self
|
|
94
|
+
|
|
95
|
+
@property
|
|
96
|
+
def lancedb_properties(self) -> LanceDBConfigBaseProperties:
|
|
97
|
+
err_msg_prefix = "LanceDB vector store configs properties:"
|
|
98
|
+
return LanceDBConfigBaseProperties(
|
|
99
|
+
similarity_top_k=validate_return_dict_prop(
|
|
100
|
+
self.properties,
|
|
101
|
+
"similarity_top_k",
|
|
102
|
+
int,
|
|
103
|
+
err_msg_prefix,
|
|
104
|
+
),
|
|
105
|
+
overfetch_factor=validate_return_dict_prop(
|
|
106
|
+
self.properties,
|
|
107
|
+
"overfetch_factor",
|
|
108
|
+
int,
|
|
109
|
+
err_msg_prefix,
|
|
110
|
+
),
|
|
111
|
+
vector_column_name=validate_return_dict_prop(
|
|
112
|
+
self.properties,
|
|
113
|
+
"vector_column_name",
|
|
114
|
+
str,
|
|
115
|
+
err_msg_prefix,
|
|
116
|
+
),
|
|
117
|
+
text_key=validate_return_dict_prop(
|
|
118
|
+
self.properties,
|
|
119
|
+
"text_key",
|
|
120
|
+
str,
|
|
121
|
+
err_msg_prefix,
|
|
122
|
+
),
|
|
123
|
+
doc_id_key=validate_return_dict_prop(
|
|
124
|
+
self.properties,
|
|
125
|
+
"doc_id_key",
|
|
126
|
+
str,
|
|
127
|
+
err_msg_prefix,
|
|
128
|
+
),
|
|
129
|
+
nprobes=validate_return_dict_prop_optional(
|
|
130
|
+
self.properties,
|
|
131
|
+
"nprobes",
|
|
132
|
+
int,
|
|
133
|
+
err_msg_prefix,
|
|
134
|
+
),
|
|
135
|
+
)
|
|
136
|
+
|
|
137
|
+
# Workaround to return typed parent without importing Project
|
|
138
|
+
def parent_project(self) -> Union["Project", None]:
|
|
139
|
+
if self.parent is None or self.parent.__class__.__name__ != "Project":
|
|
140
|
+
return None
|
|
141
|
+
return self.parent # type: ignore
|
|
@@ -3,6 +3,7 @@ import os
|
|
|
3
3
|
import subprocess
|
|
4
4
|
import sys
|
|
5
5
|
from contextlib import asynccontextmanager
|
|
6
|
+
from datetime import timedelta
|
|
6
7
|
from typing import AsyncGenerator
|
|
7
8
|
|
|
8
9
|
import httpx
|
|
@@ -171,7 +172,9 @@ class MCPSessionManager:
|
|
|
171
172
|
|
|
172
173
|
try:
|
|
173
174
|
async with stdio_client(server_params) as (read, write):
|
|
174
|
-
async with ClientSession(
|
|
175
|
+
async with ClientSession(
|
|
176
|
+
read, write, read_timeout_seconds=timedelta(seconds=8)
|
|
177
|
+
) as session:
|
|
175
178
|
await session.initialize()
|
|
176
179
|
yield session
|
|
177
180
|
except Exception as e:
|