kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +6 -0
- kiln_ai/adapters/adapter_registry.py +43 -226
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/eval_runner.py +6 -2
- kiln_ai/adapters/eval/test_base_eval.py +1 -3
- kiln_ai/adapters/eval/test_g_eval.py +1 -1
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +494 -0
- kiln_ai/adapters/ml_model_list.py +876 -18
- kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
- kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
- kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/provider_tools.py +190 -46
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/test_adapter_registry.py +579 -86
- kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
- kiln_ai/adapters/test_ml_model_list.py +202 -0
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +199 -8
- kiln_ai/adapters/test_remote_config.py +551 -56
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +16 -13
- kiln_ai/datamodel/basemodel.py +201 -4
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +27 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/external_tool_server.py +206 -54
- kiln_ai/datamodel/extraction.py +317 -0
- kiln_ai/datamodel/project.py +33 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/task.py +5 -0
- kiln_ai/datamodel/task_output.py +41 -11
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +270 -14
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_datasource.py +50 -0
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_external_tool_server.py +534 -152
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +501 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_task.py +35 -1
- kiln_ai/datamodel/test_tool_id.py +187 -1
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +58 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/base_tool.py +12 -3
- kiln_ai/tools/built_in_tools/math_tools.py +12 -4
- kiln_ai/tools/kiln_task_tool.py +158 -0
- kiln_ai/tools/mcp_server_tool.py +2 -2
- kiln_ai/tools/mcp_session_manager.py +51 -22
- kiln_ai/tools/rag_tools.py +164 -0
- kiln_ai/tools/test_kiln_task_tool.py +527 -0
- kiln_ai/tools/test_mcp_server_tool.py +4 -15
- kiln_ai/tools/test_mcp_session_manager.py +187 -227
- kiln_ai/tools/test_rag_tools.py +929 -0
- kiln_ai/tools/test_tool_registry.py +290 -7
- kiln_ai/tools/tool_registry.py +69 -16
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +2 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/open_ai_types.py +19 -2
- kiln_ai/utils/pdf_utils.py +59 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_open_ai_types.py +88 -12
- kiln_ai/utils/test_pdf_utils.py +86 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
- kiln_ai-0.22.0.dist-info/RECORD +213 -0
- kiln_ai-0.20.1.dist-info/RECORD +0 -138
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,641 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
from pydantic import ValidationError
|
|
3
|
+
|
|
4
|
+
from kiln_ai.datamodel.project import Project
|
|
5
|
+
from kiln_ai.datamodel.rag import RagConfig
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@pytest.fixture
def mock_project(tmp_path):
    """Create and save a Project under pytest's temporary directory.

    The project file lives at ``<tmp_path>/test_project/project.kiln``; the
    saved Project instance is returned.
    """
    project_dir = tmp_path / "test_project"
    project_dir.mkdir()

    saved_project = Project(name="Test Project", path=project_dir / "project.kiln")
    saved_project.save_to_file()
    return saved_project
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@pytest.fixture
def sample_rag_config_data():
    """Return keyword arguments for a fully populated RagConfig."""
    return dict(
        name="Test RAG Config",
        description="A test RAG config for testing purposes",
        tool_name="test_search_tool",
        tool_description="A test search tool for document retrieval",
        extractor_config_id="extractor123",
        chunker_config_id="chunker456",
        embedding_config_id="embedding789",
        vector_store_config_id="vector_store123",
    )
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def test_rag_config_valid_creation(sample_rag_config_data):
    """Build a RagConfig from the full sample payload and check each field."""
    config = RagConfig(**sample_rag_config_data)

    expected_values = {
        "name": "Test RAG Config",
        "description": "A test RAG config for testing purposes",
        "tool_name": "test_search_tool",
        "tool_description": "A test search tool for document retrieval",
        "extractor_config_id": "extractor123",
        "chunker_config_id": "chunker456",
        "embedding_config_id": "embedding789",
        "vector_store_config_id": "vector_store123",
    }
    for field_name, expected in expected_values.items():
        assert getattr(config, field_name) == expected

    # is_archived is not part of the sample payload, so this checks the default.
    assert not config.is_archived
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def test_rag_config_minimal_creation():
    """Construct a RagConfig without a description and verify the stored values."""
    supplied = {
        "name": "Minimal RAG Config",
        "tool_name": "minimal_search_tool",
        "tool_description": "A minimal search tool for testing",
        "extractor_config_id": "extractor123",
        "chunker_config_id": "chunker456",
        "embedding_config_id": "embedding789",
        "vector_store_config_id": "vector_store123",
    }
    config = RagConfig(**supplied)

    assert config.name == "Minimal RAG Config"
    # description was omitted from the constructor call.
    assert config.description is None
    for field_name in (
        "tool_name",
        "tool_description",
        "extractor_config_id",
        "chunker_config_id",
        "embedding_config_id",
        "vector_store_config_id",
    ):
        assert getattr(config, field_name) == supplied[field_name]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def test_rag_config_missing_required_fields():
    """Test that omitting any single required field raises ValidationError.

    Each iteration removes exactly one field from an otherwise complete set of
    constructor arguments, so the reported error can only be caused by the
    field under test. (Previously the extractor_config_id case also left out
    tool_name and tool_description, so it did not isolate the missing field.)
    """
    complete_kwargs = {
        "name": "Test Config",
        "tool_name": "test_tool",
        "tool_description": "A test tool for missing required fields testing",
        "extractor_config_id": "extractor123",
        "chunker_config_id": "chunker456",
        "embedding_config_id": "embedding789",
        "vector_store_config_id": "vector_store123",
    }

    for missing_field in complete_kwargs:
        kwargs = {
            key: value
            for key, value in complete_kwargs.items()
            if key != missing_field
        }

        with pytest.raises(ValidationError) as exc_info:
            RagConfig(**kwargs)

        errors = exc_info.value.errors()
        assert any(error["loc"][0] == missing_field for error in errors), (
            f"expected a validation error located at {missing_field!r}"
        )
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def test_rag_config_description_optional():
    """Passing description=None explicitly is accepted and stored as None."""
    kwargs = {
        "name": "Test Config",
        "description": None,
        "tool_name": "test_tool",
        "tool_description": "A test tool for description testing",
        "extractor_config_id": "extractor123",
        "chunker_config_id": "chunker456",
        "embedding_config_id": "embedding789",
        "vector_store_config_id": "vector_store123",
    }
    config = RagConfig(**kwargs)

    assert config.description is None
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def test_rag_config_description_string():
    """A string description is stored unchanged on the RagConfig."""
    description_text = "A detailed description of the RAG config"
    config = RagConfig(
        description=description_text,
        name="Test Config",
        tool_name="test_tool",
        tool_description="A test tool for description string testing",
        extractor_config_id="extractor123",
        chunker_config_id="chunker456",
        embedding_config_id="embedding789",
        vector_store_config_id="vector_store123",
    )

    assert config.description == "A detailed description of the RAG config"
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def test_rag_config_id_generation():
    """A RagConfig created without an explicit id receives a 12-character string id."""
    config = RagConfig(
        name="Test Config",
        tool_name="test_tool",
        tool_description="A test tool for ID generation",
        extractor_config_id="extractor123",
        chunker_config_id="chunker456",
        embedding_config_id="embedding789",
        vector_store_config_id="vector_store123",
    )

    generated_id = config.id
    assert generated_id is not None
    assert isinstance(generated_id, str)
    # The test pins the generated id length at 12.
    assert len(generated_id) == 12
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def test_rag_config_inheritance():
    """RagConfig exposes the attributes expected from its base model."""
    config = RagConfig(
        name="Test Config",
        tool_name="test_search_tool",
        tool_description="A test search tool for inheritance testing",
        extractor_config_id="extractor123",
        chunker_config_id="chunker456",
        embedding_config_id="embedding789",
        vector_store_config_id="vector_store123",
    )

    # Base-model attributes: schema version, identifier, file path, creation
    # timestamp, creator, and parent reference.
    base_attributes = ("v", "id", "path", "created_at", "created_by", "parent")
    for attribute in base_attributes:
        assert hasattr(config, attribute)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def test_rag_config_model_type():
    """The model_type of a RagConfig is the string "rag_config"."""
    kwargs = dict(
        name="Test Config",
        tool_name="test_search_tool",
        tool_description="A test search tool for model type testing",
        extractor_config_id="extractor123",
        chunker_config_id="chunker456",
        embedding_config_id="embedding789",
        vector_store_config_id="vector_store123",
    )

    assert RagConfig(**kwargs).model_type == "rag_config"
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def test_rag_config_config_id_types():
    """Config ID fields accept differently shaped strings and store them verbatim."""
    cases = [
        # Numeric strings
        (
            "A test search tool for config ID testing",
            {
                "extractor_config_id": "123",
                "chunker_config_id": "456",
                "embedding_config_id": "789",
                "vector_store_config_id": "999",
            },
        ),
        # UUID-like, dash-separated strings
        (
            "A test search tool for UUID-like config ID testing",
            {
                "extractor_config_id": "extractor-123-456-789",
                "chunker_config_id": "chunker-abc-def-ghi",
                "embedding_config_id": "embedding-xyz-uvw-rst",
                "vector_store_config_id": "vector-store-abc-def-ghi",
            },
        ),
    ]

    for description, config_ids in cases:
        config = RagConfig(
            name="Test Config",
            tool_name="test_search_tool",
            tool_description=description,
            **config_ids,
        )

        for field_name, expected_id in config_ids.items():
            assert getattr(config, field_name) == expected_id
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def test_rag_config_serialization():
    """Round-trip a RagConfig through model_dump() and the constructor."""
    source = RagConfig(
        name="Test Config",
        description="A test config",
        tool_name="test_search_tool",
        tool_description="A test search tool for serialization testing",
        extractor_config_id="extractor123",
        chunker_config_id="chunker456",
        embedding_config_id="embedding789",
        vector_store_config_id="vector_store123",
    )

    # Dump to a plain dict, then rebuild a new instance from that dict.
    restored = RagConfig(**source.model_dump())

    compared_fields = (
        "name",
        "description",
        "tool_name",
        "tool_description",
        "extractor_config_id",
        "chunker_config_id",
        "embedding_config_id",
        "vector_store_config_id",
    )
    for field_name in compared_fields:
        assert getattr(restored, field_name) == getattr(source, field_name)
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def test_rag_config_default_values():
    """Fields not passed to the constructor take their expected defaults."""
    kwargs = {
        "name": "Test Config",
        "tool_name": "test_search_tool",
        "tool_description": "A test search tool for default values testing",
        "extractor_config_id": "extractor123",
        "chunker_config_id": "chunker456",
        "embedding_config_id": "embedding789",
        "vector_store_config_id": "vector_store123",
    }
    config = RagConfig(**kwargs)

    assert config.description is None
    # Schema version default.
    assert config.v == 1
    # An id is generated automatically.
    assert config.id is not None
    # Neither a path nor a parent is set unless provided.
    assert config.path is None
    assert config.parent is None
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
def test_project_has_rag_configs(mock_project):
    """Test relationship between project and RagConfig.

    Two RagConfigs are created with the project as parent and saved; the
    project must then report exactly those two children.
    """
    # create 2 rag configs
    rag_config_1 = RagConfig(
        parent=mock_project,
        name="Test Config 1",
        tool_name="test_search_tool_1",
        tool_description="First test search tool for project relationship testing",
        extractor_config_id="extractor123",
        chunker_config_id="chunker456",
        embedding_config_id="embedding789",
        vector_store_config_id="vector_store123",
    )

    rag_config_2 = RagConfig(
        parent=mock_project,
        name="Test Config 2",
        tool_name="test_search_tool_2",
        tool_description="Second test search tool for project relationship testing",
        extractor_config_id="extractor123",
        chunker_config_id="chunker456",
        embedding_config_id="embedding789",
        vector_store_config_id="vector_store456",
    )

    # save the rag configs
    rag_config_1.save_to_file()
    rag_config_2.save_to_file()

    # check that the project has the rag configs
    child_rag_configs = mock_project.rag_configs()
    assert len(child_rag_configs) == 2

    # Compare as sets so that both saved configs must be present; a plain
    # membership check would also pass if one child were returned twice.
    child_ids = {rag_config.id for rag_config in child_rag_configs}
    assert child_ids == {rag_config_1.id, rag_config_2.id}
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
def test_parent_project(mock_project):
    """parent_project() returns the very Project instance passed as parent."""
    kwargs = dict(
        name="Test Config",
        tool_name="test_search_tool",
        tool_description="A test search tool for parent project testing",
        extractor_config_id="extractor123",
        chunker_config_id="chunker456",
        embedding_config_id="embedding789",
        vector_store_config_id="vector_store123",
    )
    config = RagConfig(parent=mock_project, **kwargs)

    resolved_parent = config.parent_project()
    assert resolved_parent is mock_project
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def test_rag_config_parent_project_none():
    """parent_project() returns None when no parent was supplied."""
    kwargs = dict(
        name="Test Config",
        tool_name="test_search_tool",
        tool_description="A test search tool for parent project none testing",
        extractor_config_id="extractor123",
        chunker_config_id="chunker456",
        embedding_config_id="embedding789",
        vector_store_config_id="vector_store123",
    )
    orphan_config = RagConfig(**kwargs)

    assert orphan_config.parent_project() is None
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
def test_rag_config_tags_with_none():
    """Explicitly passing tags=None is accepted and stored as None."""
    kwargs = {
        "name": "Test Config",
        "tool_name": "test_search_tool",
        "tool_description": "A test search tool for tags none testing",
        "extractor_config_id": "extractor123",
        "chunker_config_id": "chunker456",
        "embedding_config_id": "embedding789",
        "vector_store_config_id": "vector_store123",
        "tags": None,
    }
    config = RagConfig(**kwargs)

    assert config.tags is None
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
def test_rag_config_tags_with_valid_tags():
    """A list of space-free, non-empty string tags is accepted and stored."""
    tags = ["python", "ml", "backend", "api"]
    config = RagConfig(
        tags=tags,
        name="Test Config",
        tool_name="test_search_tool",
        tool_description="A test search tool for valid tags testing",
        extractor_config_id="extractor123",
        chunker_config_id="chunker456",
        embedding_config_id="embedding789",
        vector_store_config_id="vector_store123",
    )

    stored_tags = config.tags
    assert stored_tags == tags
    assert isinstance(stored_tags, list)
    for tag in stored_tags:
        assert isinstance(tag, str)
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
@pytest.mark.parametrize(
    "invalid_tags,expected_error",
    [
        # An empty list is rejected outright.
        ([], "Tags cannot be an empty list"),
        # Tags with whitespace anywhere in them.
        (["python", "with spaces", "ml"], "Tags cannot contain spaces. Try underscores."),
        (["python", " ", "ml"], "Tags cannot contain spaces. Try underscores."),
        (["python", " leading_space"], "Tags cannot contain spaces. Try underscores."),
        (["trailing_space ", "ml"], "Tags cannot contain spaces. Try underscores."),
        # An empty-string tag.
        (["", "ml"], "Tags cannot be empty."),
    ],
)
def test_rag_config_tags_invalid(invalid_tags, expected_error):
    """Invalid tag lists raise ValueError carrying the expected message."""
    kwargs = dict(
        name="Test Config",
        tool_name="test_search_tool",
        tool_description="A test search tool for invalid tags testing",
        extractor_config_id="extractor123",
        chunker_config_id="chunker456",
        embedding_config_id="embedding789",
        vector_store_config_id="vector_store123",
    )

    with pytest.raises(ValueError) as exc_info:
        RagConfig(tags=invalid_tags, **kwargs)

    raised_message = str(exc_info.value)
    assert expected_error in raised_message
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
def test_rag_config_tool_description_string_values():
    """tool_description stores short, long, multi-line and Unicode strings verbatim."""
    descriptions = (
        "Simple description",
        "A very detailed description of what this tool does and how it should be used by the model.",
        "Description with\nnewlines\nand special chars!@#$%^&*()",
        "Multi-line description\nwith detailed explanation\nof tool capabilities",
        "Description with Unicode: 测试描述 🚀",
    )
    shared_kwargs = {
        "name": "Test Config",
        "tool_name": "test_tool",
        "extractor_config_id": "extractor123",
        "chunker_config_id": "chunker456",
        "embedding_config_id": "embedding789",
        "vector_store_config_id": "vector_store123",
    }

    for description in descriptions:
        config = RagConfig(tool_description=description, **shared_kwargs)
        assert config.tool_description == description
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
def test_rag_config_tool_fields_in_model_dump():
    """model_dump() output contains tool_name and tool_description with their values."""
    config = RagConfig(
        name="Test Config",
        tool_name="serialization_test_tool",
        tool_description="A tool for testing serialization of tool fields",
        extractor_config_id="extractor123",
        chunker_config_id="chunker456",
        embedding_config_id="embedding789",
        vector_store_config_id="vector_store123",
    )

    dumped = config.model_dump()

    # Both keys must be present before their values are compared.
    for key in ("tool_name", "tool_description"):
        assert key in dumped
    assert dumped["tool_name"] == "serialization_test_tool"
    expected_description = "A tool for testing serialization of tool fields"
    assert dumped["tool_description"] == expected_description
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
@pytest.mark.parametrize(
    "tool_name,tool_description,expected_error",
    [
        # tool_name is the empty string
        ("", "Valid description", "Tool name cannot be empty"),
        # tool_description is the empty string
        ("valid_tool", "", "Tool description cannot be empty"),
        # tool_name consists only of spaces
        (" ", "Valid description", "Tool name cannot be empty"),
        # tool_description consists only of spaces
        ("valid_tool", " ", "Tool description cannot be empty"),
        # tab/newline-only values for each field
        ("\t\n", "Valid description", "Tool name cannot be empty"),
        ("valid_tool", "\t\n", "Tool description cannot be empty"),
    ],
)
def test_rag_config_tool_fields_validation_edge_cases(
    tool_name, tool_description, expected_error
):
    """Empty or whitespace-only tool_name/tool_description raise ValueError."""
    kwargs = {
        "name": "Test Config",
        "tool_name": tool_name,
        "tool_description": tool_description,
        "extractor_config_id": "extractor123",
        "chunker_config_id": "chunker456",
        "embedding_config_id": "embedding789",
        "vector_store_config_id": "vector_store123",
    }

    # match= searches the exception text for the expected message.
    with pytest.raises(ValueError, match=expected_error):
        RagConfig(**kwargs)
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
@pytest.mark.parametrize(
    "tool_name,expected_error",
    [
        ("Invalid Tool Name", "Tool name must be in snake_case"),
        ("", "Tool name cannot be empty"),
        ("a" * 65, "Tool name must be less than 64 characters long"),
    ],
)
def test_rag_config_tool_name_validation(tool_name, expected_error):
    """Invalid tool names are rejected when constructing a RagConfig."""
    # Integration check only: confirms the validator is wired into RagConfig.
    # The validator itself is tested exhaustively in utils/test_validation.py.
    kwargs = dict(
        name="Test Config",
        tool_description="A test search tool for invalid tool name testing",
        extractor_config_id="extractor123",
        chunker_config_id="chunker456",
        embedding_config_id="embedding789",
        vector_store_config_id="vector_store123",
    )

    with pytest.raises(ValueError) as exc_info:
        RagConfig(tool_name=tool_name, **kwargs)

    raised_message = str(exc_info.value)
    assert expected_error in raised_message
|
|
572
|
+
|
|
573
|
+
|
|
574
|
+
def test_rag_config_is_archived_field():
    """Exercise is_archived when omitted, set to False, and set to True."""
    # Fields shared by all three configs; only is_archived differs.
    common_fields = {
        "name": "Test RAG Config",
        "tool_name": "test_search_tool",
        "tool_description": "A test search tool",
        "extractor_config_id": "extractor123",
        "chunker_config_id": "chunker456",
        "embedding_config_id": "embedding789",
        "vector_store_config_id": "vector_store123",
    }

    # Omitted: the config must not report itself as archived.
    default_config = RagConfig(**common_fields)
    assert not default_config.is_archived

    # Explicitly not archived.
    unarchived_config = RagConfig(**common_fields, is_archived=False)
    assert not unarchived_config.is_archived

    # Explicitly archived.
    archived_config = RagConfig(**common_fields, is_archived=True)
    assert archived_config.is_archived
|
|
613
|
+
|
|
614
|
+
|
|
615
|
+
def test_rag_config_archived_persistence(mock_project, sample_rag_config_data):
    """Verify is_archived survives a save/load round trip in both states.

    An archived config is saved and reloaded, then flipped to unarchived,
    saved again, and reloaded once more.
    """

    def load_saved_config():
        # Look the config up again by its ID under the project's path.
        return RagConfig.from_id_and_parent_path(original.id, mock_project.path)

    # Start from an archived config attached to the project.
    original = RagConfig(
        parent=mock_project,
        is_archived=True,
        **sample_rag_config_data,
    )
    original.save_to_file()

    assert original.id

    # First round trip: the archived flag must still be set.
    first_load = load_saved_config()
    assert first_load is not None
    assert first_load.is_archived

    # Unarchive the loaded copy and persist the change.
    first_load.is_archived = False
    first_load.save_to_file()

    # Second round trip: the flag must now be cleared.
    second_load = load_saved_config()
    assert second_load is not None
    assert not second_load.is_archived
|
kiln_ai/datamodel/test_task.py
CHANGED
|
@@ -254,7 +254,7 @@ def test_run_config_upgrade_old_entries():
|
|
|
254
254
|
},
|
|
255
255
|
"prompt": {
|
|
256
256
|
"name": "Dazzling Unicorn",
|
|
257
|
-
"description": "Frozen copy of prompt 'simple_prompt_builder'
|
|
257
|
+
"description": "Frozen copy of prompt 'simple_prompt_builder'.",
|
|
258
258
|
"generator_id": "simple_prompt_builder",
|
|
259
259
|
"prompt": "Generate a joke, given a theme. The theme will be provided as a word or phrase as the input to the model. The assistant should output a joke that is funny and relevant to the theme. If a style is provided, the joke should be in that style. The output should include a setup and punchline.\n\nYour response should respect the following requirements:\n1) Keep the joke on topic. If the user specifies a theme, the joke must be related to that theme.\n2) Avoid any jokes that are offensive or inappropriate. Keep the joke clean and appropriate for all audiences.\n3) Make the joke funny and engaging. It should be something that someone would want to tell to their friends. Something clever, not just a simple pun.\n",
|
|
260
260
|
"chain_of_thought_instructions": None,
|
|
@@ -296,3 +296,37 @@ def test_run_config_upgrade_old_entries():
|
|
|
296
296
|
def test_task_name_unicode_name():
    """A Task built with a non-ASCII name exposes that same name."""
    unicode_name = "你好"
    created_task = Task(name=unicode_name, instruction="Do something")
    assert created_task.name == unicode_name
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def test_task_default_run_config_id_property(tmp_path):
    """Check that default_run_config_id starts as None and can be set and cleared."""
    # Persist a task inside the temporary directory.
    task_file = tmp_path / "task.kiln"
    saved_task = Task(
        name="Test Task",
        instruction="Test instruction",
        path=task_file,
    )
    saved_task.save_to_file()

    # Persist a run config under that task. The assertions below use a
    # literal ID string rather than this config's own ID.
    properties = RunConfigProperties(
        model_name="gpt-4",
        model_provider_name="openai",
        prompt_id=PromptGenerators.SIMPLE,
        structured_output_mode=StructuredOutputMode.json_schema,
    )
    child_run_config = TaskRunConfig(
        name="Test Config",
        run_config_properties=properties,
        parent=saved_task,
    )
    child_run_config.save_to_file()

    # Nothing has been assigned yet, so the property reads back as None.
    assert saved_task.default_run_config_id is None

    # Assigning an ID string makes it readable back unchanged.
    sample_id = "123456789012"
    saved_task.default_run_config_id = sample_id
    assert saved_task.default_run_config_id == sample_id

    # Resetting to None clears it again.
    saved_task.default_run_config_id = None
    assert saved_task.default_run_config_id is None
|