kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +6 -0
- kiln_ai/adapters/adapter_registry.py +43 -226
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/eval_runner.py +6 -2
- kiln_ai/adapters/eval/test_base_eval.py +1 -3
- kiln_ai/adapters/eval/test_g_eval.py +1 -1
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +494 -0
- kiln_ai/adapters/ml_model_list.py +876 -18
- kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/provider_tools.py +190 -46
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/test_adapter_registry.py +579 -86
- kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
- kiln_ai/adapters/test_ml_model_list.py +202 -0
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +199 -8
- kiln_ai/adapters/test_remote_config.py +551 -56
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +16 -13
- kiln_ai/datamodel/basemodel.py +201 -4
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +27 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +317 -0
- kiln_ai/datamodel/project.py +33 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +270 -14
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +501 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +187 -1
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +58 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/base_tool.py +12 -3
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +158 -0
- kiln_ai/tools/mcp_server_tool.py +2 -2
- kiln_ai/tools/mcp_session_manager.py +51 -22
- kiln_ai/tools/rag_tools.py +164 -0
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +187 -227
- kiln_ai/tools/test_rag_tools.py +929 -0
- kiln_ai/tools/test_tool_registry.py +290 -7
- kiln_ai/tools/tool_registry.py +69 -16
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +2 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +59 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +86 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
- kiln_ai-0.22.0.dist-info/RECORD +213 -0
- kiln_ai-0.20.1.dist-info/RECORD +0 -138
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,929 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
from unittest.mock import AsyncMock, Mock, patch
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
|
|
6
|
+
from kiln_ai.adapters.vector_store.base_vector_store_adapter import SearchResult
|
|
7
|
+
from kiln_ai.datamodel.embedding import EmbeddingConfig
|
|
8
|
+
from kiln_ai.datamodel.project import Project
|
|
9
|
+
from kiln_ai.datamodel.rag import RagConfig
|
|
10
|
+
from kiln_ai.datamodel.vector_store import VectorStoreConfig, VectorStoreType
|
|
11
|
+
from kiln_ai.tools.base_tool import ToolCallContext
|
|
12
|
+
from kiln_ai.tools.rag_tools import ChunkContext, RagTool, format_search_results
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TestChunkContext:
    """Tests for ChunkContext.serialize().

    The serialized form is a bracketed, comma-separated ``key: value`` header
    built from ``metadata``, then a newline, the chunk text, and a trailing
    blank line (``"\\n\\n"``).
    """

    def test_chunk_context_serialize_basic(self):
        """A chunk with two metadata entries serializes to header, text, blank line."""
        chunk = ChunkContext(
            metadata={"document_id": "doc1", "chunk_idx": 0},
            text="This is test content.",
        )

        result = chunk.serialize()
        expected = "[document_id: doc1, chunk_idx: 0]\nThis is test content.\n\n"
        assert result == expected

    def test_chunk_context_serialize_empty_metadata(self):
        """Empty metadata still produces an (empty) bracketed header."""
        chunk = ChunkContext(metadata={}, text="Content without metadata.")

        result = chunk.serialize()
        expected = "[]\nContent without metadata.\n\n"
        assert result == expected

    def test_chunk_context_serialize_multiple_metadata(self):
        """Every metadata entry appears in insertion order as ``key: value``."""
        chunk = ChunkContext(
            metadata={
                "document_id": "doc123",
                "chunk_idx": 5,
                "score": 0.95,
                "source": "file.txt",
            },
            text="Multi-metadata content.",
        )

        result = chunk.serialize()
        # Python dicts preserve insertion order (3.7+), and the basic test above
        # already shows the header follows that order, so the exact string is
        # deterministic. Asserting it fully (instead of loose substrings) also
        # catches malformed separators or brackets.
        expected = (
            "[document_id: doc123, chunk_idx: 5, score: 0.95, source: file.txt]\n"
            "Multi-metadata content.\n\n"
        )
        assert result == expected

    def test_chunk_context_serialize_empty_text(self):
        """Empty text leaves the header followed by three newlines."""
        chunk = ChunkContext(metadata={"document_id": "doc1"}, text="")

        result = chunk.serialize()
        expected = "[document_id: doc1]\n\n\n"
        assert result == expected
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class TestFormatSearchResults:
    """Tests for format_search_results().

    Each SearchResult is rendered as ``[document_id: ..., chunk_idx: ...]``,
    a newline, the chunk text and a blank line; consecutive results are
    separated by ``"\\n=========\\n"``. The similarity score is not rendered.
    """

    def test_format_search_results_single_result(self):
        """A single result renders without any separator."""
        search_results = [
            SearchResult(
                document_id="doc1",
                chunk_idx=0,
                chunk_text="First chunk content",
                similarity=0.95,
            )
        ]

        result = format_search_results(search_results)
        expected = "[document_id: doc1, chunk_idx: 0]\nFirst chunk content\n\n"
        assert result == expected

    def test_format_search_results_multiple_results(self):
        """Multiple results are rendered in order and joined by the delimiter."""
        search_results = [
            SearchResult(
                document_id="doc1",
                chunk_idx=0,
                chunk_text="First chunk",
                similarity=0.95,
            ),
            SearchResult(
                document_id="doc2",
                chunk_idx=1,
                chunk_text="Second chunk",
                similarity=0.85,
            ),
        ]

        result = format_search_results(search_results)

        # Pin the exact output: substring checks alone would accept wrong
        # ordering, duplicated chunks, or a missing/extra delimiter.
        expected = (
            "[document_id: doc1, chunk_idx: 0]\nFirst chunk\n\n"
            "\n=========\n"
            "[document_id: doc2, chunk_idx: 1]\nSecond chunk\n\n"
        )
        assert result == expected

    def test_format_search_results_empty_list(self):
        """No results format to an empty string."""
        search_results: List[SearchResult] = []

        result = format_search_results(search_results)
        assert result == ""

    def test_format_search_results_preserves_search_result_data(self):
        """Document id, chunk index and multi-line text survive formatting verbatim."""
        search_results = [
            SearchResult(
                document_id="test_doc_123",
                chunk_idx=42,
                chunk_text="Complex text with\nmultiple lines\nand special chars!@#$%",
                similarity=0.7654321,
            )
        ]

        result = format_search_results(search_results)

        expected = (
            "[document_id: test_doc_123, chunk_idx: 42]\n"
            "Complex text with\nmultiple lines\nand special chars!@#$%\n\n"
        )
        assert result == expected
        # The similarity score is deliberately excluded from the formatted output.
        assert "0.7654321" not in result
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class TestRagTool:
|
|
136
|
+
"""Test the RagTool class."""
|
|
137
|
+
|
|
138
|
+
@pytest.fixture
def mock_rag_config(self):
    """Provide a RagConfig stand-in pointing at fixed vector store / embedding ids."""
    return Mock(
        spec=RagConfig,
        id="rag_config_123",
        tool_name="Test Search Tool",
        tool_description="A test search tool for RAG",
        vector_store_config_id="vector_store_456",
        embedding_config_id="embedding_789",
    )
|
|
148
|
+
|
|
149
|
+
@pytest.fixture
def mock_project(self):
    """Provide a Project stand-in; its path is what config lookups are keyed on."""
    project_stub = Mock(spec=Project)
    project_stub.configure_mock(id="project_123", path="/test/project/path")
    return project_stub
|
|
156
|
+
|
|
157
|
+
@pytest.fixture
def mock_vector_store_config(self):
    """Provide a VectorStoreConfig stand-in using the LanceDB vector store type."""
    vs_config = Mock(spec=VectorStoreConfig)
    vs_config.configure_mock(
        id="vector_store_456",
        store_type=VectorStoreType.LANCE_DB_VECTOR,
    )
    return vs_config
|
|
164
|
+
|
|
165
|
+
@pytest.fixture
def mock_embedding_config(self):
    """Provide an EmbeddingConfig stand-in matching mock_rag_config's embedding id."""
    return Mock(spec=EmbeddingConfig, id="embedding_789")
|
|
171
|
+
|
|
172
|
+
def test_rag_tool_init_success(self, mock_rag_config, mock_project):
    """Constructing a RagTool copies config fields and resolves the vector store config."""
    mock_rag_config.parent_project.return_value = mock_project
    resolved_vs_config = Mock(spec=VectorStoreConfig)

    with patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as vs_config_cls:
        vs_config_cls.from_id_and_parent_path.return_value = resolved_vs_config

        tool = RagTool("tool_123", mock_rag_config)

        # Values taken directly from the constructor arguments / RAG config.
        assert tool._id == "tool_123"
        assert tool._name == "Test Search Tool"
        assert tool._description == "A test search tool for RAG"
        assert tool._rag_config == mock_rag_config

        # The vector store config is looked up at construction time, while the
        # adapter has not been created yet.
        assert tool._vector_store_config == resolved_vs_config
        assert tool._vector_store_adapter is None

        vs_config_cls.from_id_and_parent_path.assert_called_once_with(
            "vector_store_456", "/test/project/path"
        )
|
|
195
|
+
|
|
196
|
+
def test_rag_tool_init_vector_store_config_not_found(
    self, mock_rag_config, mock_project
):
    """RagTool construction fails when the vector store config id cannot be resolved."""
    mock_rag_config.parent_project.return_value = mock_project
    expected_error = "Vector store config not found: vector_store_456"

    with patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as vs_config_cls:
        # Simulate a lookup miss.
        vs_config_cls.from_id_and_parent_path.return_value = None

        with pytest.raises(ValueError, match=expected_error):
            RagTool("tool_123", mock_rag_config)
|
|
209
|
+
|
|
210
|
+
def test_rag_tool_project_property(self, mock_rag_config, mock_project):
    """RagTool.project returns the RAG config's parent project and memoizes it."""
    mock_rag_config.parent_project.return_value = mock_project

    with patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as vs_config_cls:
        vs_config_cls.from_id_and_parent_path.return_value = Mock()

        tool = RagTool("tool_123", mock_rag_config)

        # Read the property twice; both reads must return the same project...
        for _ in range(2):
            assert tool.project == mock_project

        # ...but parent_project() must have been invoked only once.
        mock_rag_config.parent_project.assert_called_once()
|
|
225
|
+
|
|
226
|
+
def test_rag_tool_project_property_no_project(self, mock_rag_config):
    """RagTool construction fails when the RAG config has no parent project."""
    mock_rag_config.parent_project.return_value = None
    expected_error = "RAG config rag_config_123 has no project"

    with patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as vs_config_cls:
        vs_config_cls.from_id_and_parent_path.return_value = Mock()

        # The constructor reads the project (its path is needed for the vector
        # store config lookup), so the missing project surfaces immediately.
        with pytest.raises(ValueError, match=expected_error):
            RagTool("tool_123", mock_rag_config)
|
|
238
|
+
|
|
239
|
+
def test_rag_tool_embedding_property(
    self, mock_rag_config, mock_project, mock_embedding_config
):
    """RagTool.embedding resolves the (config, adapter) pair once and caches it."""
    mock_rag_config.parent_project.return_value = mock_project
    embedding_adapter = Mock()

    with (
        patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as vs_config_cls,
        patch("kiln_ai.tools.rag_tools.EmbeddingConfig") as embed_config_cls,
        patch(
            "kiln_ai.tools.rag_tools.embedding_adapter_from_type"
        ) as adapter_factory,
    ):
        vs_config_cls.from_id_and_parent_path.return_value = Mock()
        embed_config_cls.from_id_and_parent_path.return_value = mock_embedding_config
        adapter_factory.return_value = embedding_adapter

        tool = RagTool("tool_123", mock_rag_config)

        # Both reads of the property must yield the same pair...
        for _ in range(2):
            resolved_config, resolved_adapter = tool.embedding
            assert resolved_config == mock_embedding_config
            assert resolved_adapter == embedding_adapter

        # ...while the lookup and adapter construction happened exactly once.
        embed_config_cls.from_id_and_parent_path.assert_called_once_with(
            "embedding_789", "/test/project/path"
        )
        adapter_factory.assert_called_once_with(mock_embedding_config)
|
|
276
|
+
|
|
277
|
+
def test_rag_tool_embedding_property_config_not_found(
    self, mock_rag_config, mock_project
):
    """Reading RagTool.embedding raises when the embedding config id is unknown."""
    mock_rag_config.parent_project.return_value = mock_project
    expected_error = "Embedding config not found: embedding_789"

    with (
        patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as vs_config_cls,
        patch("kiln_ai.tools.rag_tools.EmbeddingConfig") as embed_config_cls,
    ):
        vs_config_cls.from_id_and_parent_path.return_value = Mock()
        embed_config_cls.from_id_and_parent_path.return_value = None

        # Construction succeeds; the error only surfaces when the property is read.
        tool = RagTool("tool_123", mock_rag_config)

        with pytest.raises(ValueError, match=expected_error):
            _ = tool.embedding
|
|
296
|
+
|
|
297
|
+
async def test_rag_tool_vector_store_property(self, mock_rag_config, mock_project):
    """Awaiting RagTool.vector_store() builds the adapter once and then reuses it."""
    mock_rag_config.parent_project.return_value = mock_project
    store_adapter = AsyncMock()
    store_config = Mock()

    with (
        patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as vs_config_cls,
        patch(
            "kiln_ai.tools.rag_tools.vector_store_adapter_for_config",
            new_callable=AsyncMock,
        ) as adapter_factory,
    ):
        vs_config_cls.from_id_and_parent_path.return_value = store_config
        adapter_factory.return_value = store_adapter

        tool = RagTool("tool_123", mock_rag_config)

        # First call builds the adapter, second call must hit the cache.
        for _ in range(2):
            assert await tool.vector_store() == store_adapter

        adapter_factory.assert_called_once_with(
            vector_store_config=store_config, rag_config=mock_rag_config
        )
|
|
329
|
+
|
|
330
|
+
async def test_rag_tool_interface_methods(self, mock_rag_config, mock_project):
    """RagTool exposes id, name, description and a function-call tool definition."""
    mock_rag_config.parent_project.return_value = mock_project

    # The tool takes a single required string argument named "query".
    query_schema = {"type": "string", "description": "The search query"}
    expected_definition = {
        "type": "function",
        "function": {
            "name": "Test Search Tool",
            "description": "A test search tool for RAG",
            "parameters": {
                "type": "object",
                "properties": {"query": query_schema},
                "required": ["query"],
            },
        },
    }

    with patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as vs_config_cls:
        vs_config_cls.from_id_and_parent_path.return_value = Mock()

        tool = RagTool("tool_123", mock_rag_config)

        assert await tool.id() == "tool_123"
        assert await tool.name() == "Test Search Tool"
        assert await tool.description() == "A test search tool for RAG"
        assert await tool.toolcall_definition() == expected_definition
|
|
365
|
+
|
|
366
|
+
async def test_rag_tool_run_vector_store_type(self, mock_rag_config, mock_project):
    """RagTool.run() on a LANCE_DB_VECTOR store embeds the query and formats hits."""
    mock_rag_config.parent_project.return_value = mock_project

    hits = [
        SearchResult(
            document_id="doc1",
            chunk_idx=0,
            chunk_text="Test content 1",
            similarity=0.95,
        ),
        SearchResult(
            document_id="doc2",
            chunk_idx=1,
            chunk_text="Test content 2",
            similarity=0.85,
        ),
    ]

    with (
        patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as vs_config_cls,
        patch("kiln_ai.tools.rag_tools.EmbeddingConfig") as embed_config_cls,
        patch(
            "kiln_ai.tools.rag_tools.embedding_adapter_from_type"
        ) as embed_adapter_factory,
        patch(
            "kiln_ai.tools.rag_tools.vector_store_adapter_for_config",
            new_callable=AsyncMock,
        ) as store_adapter_factory,
    ):
        # Config lookups.
        vs_config_cls.from_id_and_parent_path.return_value = Mock(
            store_type=VectorStoreType.LANCE_DB_VECTOR
        )
        embed_config_cls.from_id_and_parent_path.return_value = Mock()

        # Embedding adapter yields a single 4-dimensional query vector.
        embedding_output = Mock()
        embedding_output.embeddings = [Mock(vector=[0.1, 0.2, 0.3, 0.4])]
        embedder = AsyncMock()
        embedder.generate_embeddings.return_value = embedding_output
        embed_adapter_factory.return_value = embedder

        # Vector store adapter returns the canned hits.
        store = AsyncMock()
        store.search.return_value = hits
        store_adapter_factory.return_value = store

        tool = RagTool("tool_123", mock_rag_config)
        output = await tool.run(context=None, query="test query")

        assert output == (
            "[document_id: doc1, chunk_idx: 0]\nTest content 1\n\n"
            "\n=========\n"
            "[document_id: doc2, chunk_idx: 1]\nTest content 2\n\n"
        )

        embedder.generate_embeddings.assert_called_once_with(["test query"])

        store.search.assert_called_once()
        sent_query = store.search.call_args.args[0]
        assert sent_query.query_string == "test query"
        # A vector store search must carry the generated embedding.
        assert sent_query.query_embedding == [0.1, 0.2, 0.3, 0.4]
|
|
449
|
+
|
|
450
|
+
async def test_rag_tool_run_hybrid_store_type(self, mock_rag_config, mock_project):
    """RagTool.run() on a LANCE_DB_HYBRID store sends both query text and embedding."""
    mock_rag_config.parent_project.return_value = mock_project

    embedding_output = Mock()
    embedding_output.embeddings = [Mock(vector=[0.1, 0.2, 0.3, 0.4])]

    hits = [
        SearchResult(
            document_id="doc1",
            chunk_idx=0,
            chunk_text="Hybrid search result",
            similarity=0.92,
        )
    ]

    with (
        patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as vs_config_cls,
        patch("kiln_ai.tools.rag_tools.EmbeddingConfig") as embed_config_cls,
        patch(
            "kiln_ai.tools.rag_tools.embedding_adapter_from_type"
        ) as embed_adapter_factory,
        patch(
            "kiln_ai.tools.rag_tools.vector_store_adapter_for_config",
            new_callable=AsyncMock,
        ) as store_adapter_factory,
    ):
        vs_config_cls.from_id_and_parent_path.return_value = Mock(
            store_type=VectorStoreType.LANCE_DB_HYBRID
        )
        embed_config_cls.from_id_and_parent_path.return_value = Mock()

        embedder = AsyncMock()
        embedder.generate_embeddings.return_value = embedding_output
        embed_adapter_factory.return_value = embedder

        store = AsyncMock()
        store.search.return_value = hits
        store_adapter_factory.return_value = store

        tool = RagTool("tool_123", mock_rag_config)
        output = await tool.run(context=None, query="hybrid query")

        # Hybrid search still needs the query embedded.
        embedder.generate_embeddings.assert_called_once_with(["hybrid query"])

        store.search.assert_called_once()
        sent_query = store.search.call_args.args[0]
        assert sent_query.query_string == "hybrid query"
        assert sent_query.query_embedding == [0.1, 0.2, 0.3, 0.4]

        assert output == (
            "[document_id: doc1, chunk_idx: 0]\nHybrid search result\n\n"
        )
|
|
522
|
+
|
|
523
|
+
async def test_rag_tool_run_fts_store_type(self, mock_rag_config, mock_project):
    """RagTool.run() on a LANCE_DB_FTS store searches by text only, with no embedding."""
    mock_rag_config.parent_project.return_value = mock_project

    hits = [
        SearchResult(
            document_id="doc_fts",
            chunk_idx=2,
            chunk_text="FTS search result",
            similarity=0.88,
        )
    ]

    with (
        patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as vs_config_cls,
        patch("kiln_ai.tools.rag_tools.EmbeddingConfig") as embed_config_cls,
        patch(
            "kiln_ai.tools.rag_tools.embedding_adapter_from_type"
        ) as embed_adapter_factory,
        patch(
            "kiln_ai.tools.rag_tools.vector_store_adapter_for_config",
            new_callable=AsyncMock,
        ) as store_adapter_factory,
    ):
        vs_config_cls.from_id_and_parent_path.return_value = Mock(
            store_type=VectorStoreType.LANCE_DB_FTS
        )
        embed_config_cls.from_id_and_parent_path.return_value = Mock()

        # Left unconfigured on purpose: it must never be asked for embeddings.
        embedder = AsyncMock()
        embed_adapter_factory.return_value = embedder

        store = AsyncMock()
        store.search.return_value = hits
        store_adapter_factory.return_value = store

        tool = RagTool("tool_123", mock_rag_config)
        output = await tool.run(context=None, query="fts query")

        assert output == (
            "[document_id: doc_fts, chunk_idx: 2]\nFTS search result\n\n"
        )

        embedder.generate_embeddings.assert_not_called()

        store.search.assert_called_once()
        sent_query = store.search.call_args.args[0]
        assert sent_query.query_string == "fts query"
        # Full-text search must not carry an embedding.
        assert sent_query.query_embedding is None
|
|
586
|
+
|
|
587
|
+
async def test_rag_tool_run_no_embeddings_generated(
    self, mock_rag_config, mock_project
):
    """RagTool.run() must raise when the embedding adapter yields no embeddings.

    The store is configured as LANCE_DB_HYBRID, and the embedding adapter
    returns a result whose ``embeddings`` list is empty; the tool is expected
    to surface that as a ValueError.
    """
    mock_rag_config.parent_project.return_value = mock_project

    # Collaborators are built up-front and wired into the patched factories below.
    empty_result = Mock(embeddings=[])
    hybrid_store_config = Mock(store_type=VectorStoreType.LANCE_DB_HYBRID)
    embedding_config = Mock()
    embedding_adapter = AsyncMock()
    embedding_adapter.generate_embeddings.return_value = empty_result
    store_adapter = AsyncMock()

    with (
        patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as vs_config_cls,
        patch("kiln_ai.tools.rag_tools.EmbeddingConfig") as embed_config_cls,
        patch(
            "kiln_ai.tools.rag_tools.embedding_adapter_from_type",
            return_value=embedding_adapter,
        ),
        patch(
            "kiln_ai.tools.rag_tools.vector_store_adapter_for_config",
            new_callable=AsyncMock,
            return_value=store_adapter,
        ),
    ):
        vs_config_cls.from_id_and_parent_path.return_value = hybrid_store_config
        embed_config_cls.from_id_and_parent_path.return_value = embedding_config

        rag_tool = RagTool("tool_123", mock_rag_config)

        with pytest.raises(ValueError, match="No embeddings generated"):
            await rag_tool.run(context=None, query="query with no embeddings")
|
|
634
|
+
|
|
635
|
+
async def test_rag_tool_run_empty_search_results(
    self, mock_rag_config, mock_project
):
    """RagTool.run() returns an empty string when the vector search has no hits."""
    mock_rag_config.parent_project.return_value = mock_project

    vector_store_config = Mock(store_type=VectorStoreType.LANCE_DB_VECTOR)
    embedding_config = Mock()
    embedding_adapter = AsyncMock()
    embedding_adapter.generate_embeddings.return_value = Mock(
        embeddings=[Mock(vector=[0.1, 0.2, 0.3, 0.4])]
    )
    store_adapter = AsyncMock()
    store_adapter.search.return_value = []  # nothing matches the query

    with (
        patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as vs_config_cls,
        patch("kiln_ai.tools.rag_tools.EmbeddingConfig") as embed_config_cls,
        patch(
            "kiln_ai.tools.rag_tools.embedding_adapter_from_type",
            return_value=embedding_adapter,
        ),
        patch(
            "kiln_ai.tools.rag_tools.vector_store_adapter_for_config",
            new_callable=AsyncMock,
            return_value=store_adapter,
        ),
    ):
        vs_config_cls.from_id_and_parent_path.return_value = vector_store_config
        embed_config_cls.from_id_and_parent_path.return_value = embedding_config

        rag_tool = RagTool("tool_123", mock_rag_config)
        output = await rag_tool.run(context=None, query="query with no results")

        # With no hits there is nothing to format.
        assert output == ""
|
|
683
|
+
|
|
684
|
+
async def test_rag_tool_run_with_context_is_accepted(
    self, mock_rag_config, mock_project
):
    """RagTool.run() works unchanged when a ToolCallContext is supplied."""
    mock_rag_config.parent_project.return_value = mock_project

    hit = SearchResult(
        document_id="doc_ctx",
        chunk_idx=3,
        chunk_text="Context ok",
        similarity=0.77,
    )

    # A VECTOR store type makes the tool take the embedding path.
    vector_store_config = Mock(store_type=VectorStoreType.LANCE_DB_VECTOR)
    embedding_config = Mock()
    embedding_adapter = AsyncMock()
    embedding_adapter.generate_embeddings.return_value = Mock(
        embeddings=[Mock(vector=[1.0])]
    )
    store_adapter = AsyncMock()
    store_adapter.search.return_value = [hit]

    with (
        patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as vs_config_cls,
        patch("kiln_ai.tools.rag_tools.EmbeddingConfig") as embed_config_cls,
        patch(
            "kiln_ai.tools.rag_tools.embedding_adapter_from_type",
            return_value=embedding_adapter,
        ),
        patch(
            "kiln_ai.tools.rag_tools.vector_store_adapter_for_config",
            new_callable=AsyncMock,
            return_value=store_adapter,
        ),
    ):
        vs_config_cls.from_id_and_parent_path.return_value = vector_store_config
        embed_config_cls.from_id_and_parent_path.return_value = embedding_config

        rag_tool = RagTool("tool_ctx", mock_rag_config)
        call_context = ToolCallContext(allow_saving=False)
        output = await rag_tool.run(context=call_context, query="with context")

        assert output == "[document_id: doc_ctx, chunk_idx: 3]\nContext ok\n\n"

        # The usual embed-then-search flow still ran exactly once.
        embedding_adapter.generate_embeddings.assert_called_once_with(
            ["with context"]
        )
        store_adapter.search.assert_called_once()
|
|
748
|
+
|
|
749
|
+
async def test_rag_tool_run_missing_query_raises(
    self, mock_rag_config, mock_project
):
    """Exercise RagTool.run's 'if not query' guard.

    Omitting the ``query`` argument must raise a KeyError mentioning "query".
    """
    mock_rag_config.parent_project.return_value = mock_project

    with patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as vs_config_cls:
        vs_config_cls.from_id_and_parent_path.return_value = Mock()
        rag_tool = RagTool("tool_err", mock_rag_config)

        with pytest.raises(KeyError, match="query"):
            await rag_tool.run(context=None)
|
|
763
|
+
|
|
764
|
+
|
|
765
|
+
class TestRagToolNameAndDescription:
|
|
766
|
+
"""Test RagTool name and description functionality with tool_name and tool_description fields."""
|
|
767
|
+
|
|
768
|
+
@pytest.fixture
def mock_rag_config_with_tool_fields(self):
    """Provide a RagConfig mock with an explicit tool_name and tool_description."""
    # Mock forwards extra keyword arguments to configure_mock, which sets them as attributes.
    return Mock(
        spec=RagConfig,
        id="rag_config_456",
        tool_name="Advanced Document Search",
        tool_description=(
            "An advanced search tool that retrieves relevant documents "
            "from the knowledge base using semantic similarity"
        ),
        vector_store_config_id="vector_store_789",
        embedding_config_id="embedding_101",
    )
|
|
778
|
+
|
|
779
|
+
@pytest.fixture
def mock_project_for_tool_fields(self):
    """Provide the Project mock that acts as parent of the tool-field RAG configs."""
    return Mock(spec=Project, id="project_456", path="/test/tool/project")
|
|
786
|
+
|
|
787
|
+
def test_rag_tool_uses_tool_name_field(
    self, mock_rag_config_with_tool_fields, mock_project_for_tool_fields
):
    """RagTool takes its internal name from RagConfig.tool_name."""
    rag_config = mock_rag_config_with_tool_fields
    rag_config.parent_project.return_value = mock_project_for_tool_fields

    with patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as vs_config_cls:
        vs_config_cls.from_id_and_parent_path.return_value = Mock()
        rag_tool = RagTool("tool_456", rag_config)

        assert rag_tool._name == "Advanced Document Search"
|
|
801
|
+
|
|
802
|
+
def test_rag_tool_uses_tool_description_field(
    self, mock_rag_config_with_tool_fields, mock_project_for_tool_fields
):
    """RagTool takes its internal description from RagConfig.tool_description."""
    expected_description = (
        "An advanced search tool that retrieves relevant documents "
        "from the knowledge base using semantic similarity"
    )
    rag_config = mock_rag_config_with_tool_fields
    rag_config.parent_project.return_value = mock_project_for_tool_fields

    with patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as vs_config_cls:
        vs_config_cls.from_id_and_parent_path.return_value = Mock()
        rag_tool = RagTool("tool_456", rag_config)

        assert rag_tool._description == expected_description
|
|
819
|
+
|
|
820
|
+
async def test_rag_tool_name_method_returns_tool_name(
    self, mock_rag_config_with_tool_fields, mock_project_for_tool_fields
):
    """The async name() accessor reports RagConfig.tool_name."""
    rag_config = mock_rag_config_with_tool_fields
    rag_config.parent_project.return_value = mock_project_for_tool_fields

    with patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as vs_config_cls:
        vs_config_cls.from_id_and_parent_path.return_value = Mock()
        rag_tool = RagTool("tool_456", rag_config)

        assert await rag_tool.name() == "Advanced Document Search"
|
|
835
|
+
|
|
836
|
+
async def test_rag_tool_description_method_returns_tool_description(
    self, mock_rag_config_with_tool_fields, mock_project_for_tool_fields
):
    """The async description() accessor reports RagConfig.tool_description."""
    expected_description = (
        "An advanced search tool that retrieves relevant documents "
        "from the knowledge base using semantic similarity"
    )
    rag_config = mock_rag_config_with_tool_fields
    rag_config.parent_project.return_value = mock_project_for_tool_fields

    with patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as vs_config_cls:
        vs_config_cls.from_id_and_parent_path.return_value = Mock()
        rag_tool = RagTool("tool_456", rag_config)

        assert await rag_tool.description() == expected_description
|
|
854
|
+
|
|
855
|
+
async def test_rag_tool_toolcall_definition_uses_tool_fields(
    self, mock_rag_config_with_tool_fields, mock_project_for_tool_fields
):
    """toolcall_definition() embeds tool_name/tool_description in its function schema."""
    # Expected function-call schema, assembled from its parts.
    query_param = {"type": "string", "description": "The search query"}
    parameters_schema = {
        "type": "object",
        "properties": {"query": query_param},
        "required": ["query"],
    }
    expected_definition = {
        "type": "function",
        "function": {
            "name": "Advanced Document Search",
            "description": (
                "An advanced search tool that retrieves relevant documents "
                "from the knowledge base using semantic similarity"
            ),
            "parameters": parameters_schema,
        },
    }

    rag_config = mock_rag_config_with_tool_fields
    rag_config.parent_project.return_value = mock_project_for_tool_fields

    with patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as vs_config_cls:
        vs_config_cls.from_id_and_parent_path.return_value = Mock()
        rag_tool = RagTool("tool_456", rag_config)

        assert await rag_tool.toolcall_definition() == expected_definition
|
|
889
|
+
|
|
890
|
+
def test_rag_tool_with_unicode_tool_fields(self, mock_project_for_tool_fields):
    """Non-ASCII tool_name and tool_description values pass through RagTool intact."""
    unicode_name = "🔍 文档搜索工具"
    unicode_description = "一个用于搜索文档的高级工具 🚀"

    rag_config = Mock(
        spec=RagConfig,
        id="rag_config_unicode",
        tool_name=unicode_name,
        tool_description=unicode_description,
        vector_store_config_id="vector_store_789",
        embedding_config_id="embedding_101",
    )
    rag_config.parent_project.return_value = mock_project_for_tool_fields

    with patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as vs_config_cls:
        vs_config_cls.from_id_and_parent_path.return_value = Mock()

        rag_tool = RagTool("tool_unicode", rag_config)
        assert rag_tool._name == unicode_name
        assert rag_tool._description == unicode_description
|
|
906
|
+
|
|
907
|
+
def test_rag_tool_with_multiline_tool_description(
    self, mock_project_for_tool_fields
):
    """A tool_description spanning several lines is stored verbatim."""
    multiline_description = (
        "This is a comprehensive search tool that:\n"
        "- Searches through document collections\n"
        "- Uses semantic similarity matching\n"
        "- Returns relevant context with metadata\n"
        "- Supports various document formats"
    )

    rag_config = Mock(
        spec=RagConfig,
        id="rag_config_multiline",
        tool_name="Comprehensive Search Tool",
        tool_description=multiline_description,
        vector_store_config_id="vector_store_789",
        embedding_config_id="embedding_101",
    )
    rag_config.parent_project.return_value = mock_project_for_tool_fields

    with patch("kiln_ai.tools.rag_tools.VectorStoreConfig") as vs_config_cls:
        vs_config_cls.from_id_and_parent_path.return_value = Mock()

        rag_tool = RagTool("tool_multiline", rag_config)
        assert rag_tool._description == multiline_description
|