kiln-ai 0.20.1__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic. Click here for more details.

Files changed (117)
  1. kiln_ai/adapters/__init__.py +6 -0
  2. kiln_ai/adapters/adapter_registry.py +43 -226
  3. kiln_ai/adapters/chunkers/__init__.py +13 -0
  4. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  5. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  6. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  7. kiln_ai/adapters/chunkers/helpers.py +23 -0
  8. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  9. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  10. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  11. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  12. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  13. kiln_ai/adapters/embedding/__init__.py +0 -0
  14. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  15. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  16. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  17. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  18. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  19. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  20. kiln_ai/adapters/eval/eval_runner.py +6 -2
  21. kiln_ai/adapters/eval/test_base_eval.py +1 -3
  22. kiln_ai/adapters/eval/test_g_eval.py +1 -1
  23. kiln_ai/adapters/extractors/__init__.py +18 -0
  24. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  25. kiln_ai/adapters/extractors/encoding.py +20 -0
  26. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  27. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  28. kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
  29. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  30. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  31. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  32. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  33. kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
  34. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  35. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
  36. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  37. kiln_ai/adapters/ml_embedding_model_list.py +192 -0
  38. kiln_ai/adapters/ml_model_list.py +382 -4
  39. kiln_ai/adapters/model_adapters/litellm_adapter.py +7 -69
  40. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +1 -1
  41. kiln_ai/adapters/model_adapters/test_structured_output.py +3 -1
  42. kiln_ai/adapters/ollama_tools.py +69 -12
  43. kiln_ai/adapters/provider_tools.py +190 -46
  44. kiln_ai/adapters/rag/deduplication.py +49 -0
  45. kiln_ai/adapters/rag/progress.py +252 -0
  46. kiln_ai/adapters/rag/rag_runners.py +844 -0
  47. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  48. kiln_ai/adapters/rag/test_progress.py +785 -0
  49. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  50. kiln_ai/adapters/remote_config.py +80 -8
  51. kiln_ai/adapters/test_adapter_registry.py +579 -86
  52. kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
  53. kiln_ai/adapters/test_ml_model_list.py +212 -0
  54. kiln_ai/adapters/test_ollama_tools.py +340 -1
  55. kiln_ai/adapters/test_prompt_builders.py +1 -1
  56. kiln_ai/adapters/test_provider_tools.py +199 -8
  57. kiln_ai/adapters/test_remote_config.py +551 -56
  58. kiln_ai/adapters/vector_store/__init__.py +1 -0
  59. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  60. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  61. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  62. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  63. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  64. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  65. kiln_ai/datamodel/__init__.py +16 -13
  66. kiln_ai/datamodel/basemodel.py +170 -1
  67. kiln_ai/datamodel/chunk.py +158 -0
  68. kiln_ai/datamodel/datamodel_enums.py +27 -0
  69. kiln_ai/datamodel/embedding.py +64 -0
  70. kiln_ai/datamodel/extraction.py +303 -0
  71. kiln_ai/datamodel/project.py +33 -1
  72. kiln_ai/datamodel/rag.py +79 -0
  73. kiln_ai/datamodel/test_attachment.py +649 -0
  74. kiln_ai/datamodel/test_basemodel.py +1 -1
  75. kiln_ai/datamodel/test_chunk_models.py +317 -0
  76. kiln_ai/datamodel/test_dataset_split.py +1 -1
  77. kiln_ai/datamodel/test_embedding_models.py +448 -0
  78. kiln_ai/datamodel/test_eval_model.py +6 -6
  79. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  80. kiln_ai/datamodel/test_extraction_model.py +470 -0
  81. kiln_ai/datamodel/test_rag.py +641 -0
  82. kiln_ai/datamodel/test_tool_id.py +81 -0
  83. kiln_ai/datamodel/test_vector_store.py +320 -0
  84. kiln_ai/datamodel/tool_id.py +22 -0
  85. kiln_ai/datamodel/vector_store.py +141 -0
  86. kiln_ai/tools/mcp_session_manager.py +4 -1
  87. kiln_ai/tools/rag_tools.py +157 -0
  88. kiln_ai/tools/test_mcp_session_manager.py +1 -1
  89. kiln_ai/tools/test_rag_tools.py +848 -0
  90. kiln_ai/tools/test_tool_registry.py +91 -2
  91. kiln_ai/tools/tool_registry.py +21 -0
  92. kiln_ai/utils/__init__.py +3 -0
  93. kiln_ai/utils/async_job_runner.py +62 -17
  94. kiln_ai/utils/config.py +2 -2
  95. kiln_ai/utils/env.py +15 -0
  96. kiln_ai/utils/filesystem.py +14 -0
  97. kiln_ai/utils/filesystem_cache.py +60 -0
  98. kiln_ai/utils/litellm.py +94 -0
  99. kiln_ai/utils/lock.py +100 -0
  100. kiln_ai/utils/mime_type.py +38 -0
  101. kiln_ai/utils/pdf_utils.py +38 -0
  102. kiln_ai/utils/test_async_job_runner.py +151 -35
  103. kiln_ai/utils/test_env.py +142 -0
  104. kiln_ai/utils/test_filesystem_cache.py +316 -0
  105. kiln_ai/utils/test_litellm.py +206 -0
  106. kiln_ai/utils/test_lock.py +185 -0
  107. kiln_ai/utils/test_mime_type.py +66 -0
  108. kiln_ai/utils/test_pdf_utils.py +73 -0
  109. kiln_ai/utils/test_uuid.py +111 -0
  110. kiln_ai/utils/test_validation.py +524 -0
  111. kiln_ai/utils/uuid.py +9 -0
  112. kiln_ai/utils/validation.py +90 -0
  113. {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +7 -1
  114. kiln_ai-0.21.0.dist-info/RECORD +211 -0
  115. kiln_ai-0.20.1.dist-info/RECORD +0 -138
  116. {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
  117. {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -4,10 +4,12 @@ from pydantic import BaseModel, ValidationError
4
4
  from kiln_ai.datamodel.tool_id import (
5
5
  MCP_LOCAL_TOOL_ID_PREFIX,
6
6
  MCP_REMOTE_TOOL_ID_PREFIX,
7
+ RAG_TOOL_ID_PREFIX,
7
8
  KilnBuiltInToolId,
8
9
  ToolId,
9
10
  _check_tool_id,
10
11
  mcp_server_and_tool_name_from_id,
12
+ rag_config_id_from_id,
11
13
  )
12
14
 
13
15
 
@@ -113,6 +115,36 @@ class TestCheckToolId:
113
115
  with pytest.raises(ValueError, match="Invalid tool ID"):
114
116
  _check_tool_id("mcp::wrong::server::tool")
115
117
 
118
def test_valid_rag_tools(self):
    """_check_tool_id returns well-formed RAG tool IDs unchanged."""
    for tool_id in (
        "kiln_tool::rag::config1",
        "kiln_tool::rag::my_rag_config",
        "kiln_tool::rag::test_config_123",
    ):
        assert _check_tool_id(tool_id) == tool_id
128
+
129
def test_invalid_rag_format(self):
    """RAG-prefixed IDs with no config ID or with extra segments are rejected."""
    bad_ids = [
        "kiln_tool::rag::",  # nothing after the RAG prefix
        "kiln_tool::rag::config::extra",  # an extra '::' segment
    ]
    for bad_id in bad_ids:
        with pytest.raises(ValueError, match="Invalid RAG tool ID"):
            _check_tool_id(bad_id)
140
+
141
def test_rag_tool_empty_config_id(self):
    """Test that a RAG tool ID with an empty config ID is rejected."""
    # rag_config_id_from_id("kiln_tool::rag::") returns "" instead of raising,
    # so the rejection here comes from _check_tool_id's own empty-config-ID check.
    with pytest.raises(ValueError, match="Invalid RAG tool ID"):
        _check_tool_id("kiln_tool::rag::")
147
+
116
148
 
117
149
  class TestMcpServerAndToolNameFromId:
118
150
  """Test the mcp_server_and_tool_name_from_id function."""
@@ -197,6 +229,9 @@ class TestToolIdPydanticType:
197
229
  # Local MCP tools
198
230
  "mcp::local::server1::tool1",
199
231
  "mcp::local::my_server::my_tool",
232
+ # RAG tools
233
+ "kiln_tool::rag::config1",
234
+ "kiln_tool::rag::my_rag_config",
200
235
  ]
201
236
 
202
237
  for tool_id in valid_ids:
@@ -212,6 +247,8 @@ class TestToolIdPydanticType:
212
247
  "mcp::remote::server",
213
248
  "mcp::local::",
214
249
  "mcp::local::server",
250
+ "kiln_tool::rag::",
251
+ "kiln_tool::rag::config::extra",
215
252
  ]
216
253
 
217
254
  for invalid_id in invalid_ids:
@@ -237,3 +274,47 @@ class TestConstants:
237
274
def test_mcp_local_tool_id_prefix(self):
    """The local MCP tool ID prefix constant has its expected literal value."""
    expected_prefix = "mcp::local::"
    assert expected_prefix == MCP_LOCAL_TOOL_ID_PREFIX
277
+
278
def test_rag_tool_id_prefix(self):
    """The RAG tool ID prefix constant has its expected literal value."""
    expected_prefix = "kiln_tool::rag::"
    assert expected_prefix == RAG_TOOL_ID_PREFIX
281
+
282
+
283
class TestRagConfigIdFromId:
    """Test the rag_config_id_from_id function."""

    @pytest.mark.parametrize(
        ("tool_id", "expected"),
        [
            ("kiln_tool::rag::config1", "config1"),
            ("kiln_tool::rag::my_rag_config", "my_rag_config"),
            ("kiln_tool::rag::test_config_123", "test_config_123"),
            ("kiln_tool::rag::a", "a"),  # Minimal valid case
        ],
    )
    def test_valid_rag_ids(self, tool_id, expected):
        """Test parsing valid RAG tool IDs."""
        assert rag_config_id_from_id(tool_id) == expected

    @pytest.mark.parametrize(
        "invalid_id",
        [
            "kiln_tool::rag::config::extra",  # Too many '::' segments (4)
            "wrong::rag::config",  # Wrong prefix
            "kiln_tool::wrong::config",  # Wrong middle part
            "rag::config",  # Too few segments (2)
            "",  # Empty string
            "single_part",  # Only 1 segment
        ],
    )
    def test_invalid_rag_ids(self, invalid_id):
        """Test parsing fails for IDs not shaped like kiln_tool::rag::<config_id>."""
        # Each case either lacks the RAG prefix or does not split on '::' into
        # exactly three segments, so the parser must raise.
        with pytest.raises(ValueError, match="Invalid RAG tool ID"):
            rag_config_id_from_id(invalid_id)

    def test_rag_id_with_empty_config_id(self):
        """Test that a RAG tool ID with an empty config ID returns an empty string."""
        # The parser itself accepts the bare prefix and returns "";
        # rejecting an empty config ID is _check_tool_id's responsibility.
        result = rag_config_id_from_id("kiln_tool::rag::")
        assert result == ""
@@ -0,0 +1,320 @@
1
+ import pytest
2
+ from pydantic import ValidationError
3
+
4
+ from kiln_ai.datamodel.project import Project
5
+ from kiln_ai.datamodel.vector_store import (
6
+ LanceDBConfigBaseProperties,
7
+ VectorStoreConfig,
8
+ VectorStoreType,
9
+ )
10
+
11
+
12
@pytest.fixture
def mock_project(tmp_path):
    """Create a Project inside a fresh temp directory and save it to disk."""
    project_dir = tmp_path / "test_project"
    project_dir.mkdir()

    saved_project = Project(name="Test Project", path=project_dir / "project.kiln")
    saved_project.save_to_file()
    return saved_project
21
+
22
+
23
@pytest.fixture
def mock_vector_store_fts_config_properties():
    """LanceDB properties covering every key an FTS store requires (no nprobes)."""
    return dict(
        similarity_top_k=10,
        overfetch_factor=2,
        vector_column_name="vector",
        text_key="text",
        doc_id_key="doc_id",
    )
32
+
33
+
34
@pytest.fixture
def mock_vector_store_vector_config_properties():
    """LanceDB properties for vector/hybrid stores, which additionally need nprobes."""
    return dict(
        similarity_top_k=10,
        overfetch_factor=2,
        vector_column_name="vector",
        text_key="text",
        doc_id_key="doc_id",
        nprobes=1,
    )
44
+
45
+
46
class TestVectorStoreType:
    def test_vector_store_type_values(self):
        """Each VectorStoreType member compares equal to its string value."""
        expected_pairs = [
            (VectorStoreType.LANCE_DB_FTS, "lancedb_fts"),
            (VectorStoreType.LANCE_DB_HYBRID, "lancedb_hybrid"),
            (VectorStoreType.LANCE_DB_VECTOR, "lancedb_vector"),
        ]
        for member, raw_value in expected_pairs:
            assert member == raw_value
+
53
+
54
class TestLanceDBConfigBaseProperties:
    def test_valid_lance_db_config_base_properties(self):
        """Every field, including nprobes, is stored on the model as given."""
        field_values = {
            "similarity_top_k": 10,
            "overfetch_factor": 2,
            "vector_column_name": "vector",
            "text_key": "text",
            "doc_id_key": "doc_id",
            "nprobes": 1,
        }
        config = LanceDBConfigBaseProperties(**field_values)

        for field_name, expected in field_values.items():
            assert getattr(config, field_name) == expected

    def test_lance_db_config_base_properties_without_nprobes(self):
        """Leaving nprobes out falls back to its None default."""
        config = LanceDBConfigBaseProperties(
            similarity_top_k=10,
            overfetch_factor=2,
            vector_column_name="vector",
            text_key="text",
            doc_id_key="doc_id",
        )

        assert config.similarity_top_k == 10
        assert config.nprobes is None
85
+
86
+
87
class TestVectorStoreConfig:
    """Tests for VectorStoreConfig construction, validation and project linkage."""

    def test_invalid_store_type(self, mock_vector_store_fts_config_properties):
        """Test creating VectorStoreConfig with invalid store type."""
        with pytest.raises(ValidationError, match="Input should be"):
            VectorStoreConfig(
                name="test_store",
                store_type="invalid_type",  # type: ignore
                properties=mock_vector_store_fts_config_properties,
            )

    def test_invalid_store_type_after_creation(
        self, mock_vector_store_fts_config_properties
    ):
        """Test assigning an invalid store type to an existing VectorStoreConfig."""
        config = VectorStoreConfig(
            name="test_store",
            store_type=VectorStoreType.LANCE_DB_FTS,
            properties=mock_vector_store_fts_config_properties,
        )
        # Only raises if the model validates on attribute assignment.
        with pytest.raises(ValidationError, match="Input should be"):
            config.store_type = "invalid_type"  # type: ignore

    def test_valid_lance_db_fts_vector_store_config(
        self, mock_vector_store_fts_config_properties
    ):
        """Test creating valid VectorStoreConfig with LanceDB FTS."""
        config = VectorStoreConfig(
            name="test_store",
            store_type=VectorStoreType.LANCE_DB_FTS,
            properties=mock_vector_store_fts_config_properties,
        )

        assert config.name == "test_store"
        assert config.store_type == VectorStoreType.LANCE_DB_FTS
        assert config.properties["similarity_top_k"] == 10
        assert config.properties["overfetch_factor"] == 2
        assert config.properties["vector_column_name"] == "vector"
        assert config.properties["text_key"] == "text"
        assert config.properties["doc_id_key"] == "doc_id"

    def test_valid_lance_db_vector_store_config(
        self, mock_vector_store_vector_config_properties
    ):
        """Test creating valid VectorStoreConfig with LanceDB Vector."""
        config = VectorStoreConfig(
            name="test_store",
            store_type=VectorStoreType.LANCE_DB_VECTOR,
            properties=mock_vector_store_vector_config_properties,
        )

        assert config.name == "test_store"
        assert config.store_type == VectorStoreType.LANCE_DB_VECTOR
        assert config.properties["similarity_top_k"] == 10
        assert config.properties["nprobes"] == 1

    def test_valid_lance_db_hybrid_store_config(
        self, mock_vector_store_vector_config_properties
    ):
        """Test creating valid VectorStoreConfig with LanceDB Hybrid."""
        config = VectorStoreConfig(
            name="test_store",
            store_type=VectorStoreType.LANCE_DB_HYBRID,
            properties=mock_vector_store_vector_config_properties,
        )

        assert config.name == "test_store"
        assert config.store_type == VectorStoreType.LANCE_DB_HYBRID
        assert config.properties["nprobes"] == 1

    def test_vector_store_config_missing_required_property(
        self, mock_vector_store_fts_config_properties
    ):
        """Test VectorStoreConfig validation fails when required property is missing."""
        mock_vector_store_fts_config_properties.pop("similarity_top_k")
        with pytest.raises(
            ValidationError,
            match=r".*similarity_top_k is a required property",
        ):
            VectorStoreConfig(
                name="test_store",
                store_type=VectorStoreType.LANCE_DB_FTS,
                properties=mock_vector_store_fts_config_properties,
            )

    def test_vector_store_config_invalid_property_type(
        self, mock_vector_store_fts_config_properties
    ):
        """Test VectorStoreConfig validation fails when property has wrong type."""
        mock_vector_store_fts_config_properties["similarity_top_k"] = "not_an_int"
        with pytest.raises(
            ValidationError,
            match=r".*similarity_top_k must be of type",
        ):
            VectorStoreConfig(
                name="test_store",
                store_type=VectorStoreType.LANCE_DB_FTS,
                properties=mock_vector_store_fts_config_properties,
            )

    def test_vector_store_config_fts_missing_nprobes_is_valid(
        self, mock_vector_store_fts_config_properties
    ):
        """Test VectorStoreConfig with FTS type doesn't require nprobes."""
        config = VectorStoreConfig(
            name="test_store",
            store_type=VectorStoreType.LANCE_DB_FTS,
            properties=mock_vector_store_fts_config_properties,
        )
        assert config.store_type == VectorStoreType.LANCE_DB_FTS

    @pytest.mark.parametrize(
        "store_type",
        [VectorStoreType.LANCE_DB_VECTOR, VectorStoreType.LANCE_DB_HYBRID],
    )
    def test_vector_store_config_vector_missing_nprobes_fails(
        self, store_type, mock_vector_store_vector_config_properties
    ):
        """Test that both VECTOR and HYBRID store types require nprobes."""
        mock_vector_store_vector_config_properties.pop("nprobes")
        with pytest.raises(
            ValidationError,
            match=r".*nprobes is a required property",
        ):
            VectorStoreConfig(
                name="test_store",
                store_type=store_type,
                properties=mock_vector_store_vector_config_properties,
            )

    def test_lancedb_properties(self, mock_vector_store_vector_config_properties):
        """Test lancedb_properties method returns correct LanceDBConfigBaseProperties."""
        config = VectorStoreConfig(
            name="test_store",
            store_type=VectorStoreType.LANCE_DB_VECTOR,
            properties=mock_vector_store_vector_config_properties,
        )

        props = config.lancedb_properties

        assert isinstance(props, LanceDBConfigBaseProperties)
        assert props.similarity_top_k == 10
        assert props.overfetch_factor == 2
        assert props.vector_column_name == "vector"
        assert props.text_key == "text"
        assert props.doc_id_key == "doc_id"
        assert props.nprobes == 1

    def test_vector_store_config_inherits_from_kiln_parented_model(
        self, mock_vector_store_fts_config_properties
    ):
        """Test that VectorStoreConfig inherits from KilnParentedModel."""
        config = VectorStoreConfig(
            name="test_store",
            store_type=VectorStoreType.LANCE_DB_FTS,
            properties=mock_vector_store_fts_config_properties,
        )

        # Check that it has the expected base fields
        for base_field in ("id", "v", "created_at", "created_by", "parent"):
            assert hasattr(config, base_field)

    @pytest.mark.parametrize(
        "name",
        ["valid_name", "valid name", "valid-name", "valid_name_123", "VALID_NAME"],
    )
    def test_vector_store_config_valid_names(
        self, name, mock_vector_store_fts_config_properties
    ):
        """Test VectorStoreConfig accepts valid names."""
        config = VectorStoreConfig(
            name=name,
            store_type=VectorStoreType.LANCE_DB_FTS,
            properties=mock_vector_store_fts_config_properties,
        )
        assert config.name == name

    @pytest.mark.parametrize(
        "name",
        [
            "",
            "a" * 121,  # Too long
        ],
    )
    def test_vector_store_config_invalid_names(
        self, name, mock_vector_store_fts_config_properties
    ):
        """Test VectorStoreConfig rejects invalid names."""
        with pytest.raises(ValidationError):
            VectorStoreConfig(
                name=name,
                store_type=VectorStoreType.LANCE_DB_FTS,
                properties=mock_vector_store_fts_config_properties,
            )

    def test_parent_project(
        self, mock_project, mock_vector_store_fts_config_properties
    ):
        """Test that parent project is returned correctly."""
        config = VectorStoreConfig(
            name="test_store",
            store_type=VectorStoreType.LANCE_DB_FTS,
            properties=mock_vector_store_fts_config_properties,
            parent=mock_project,
        )

        assert config.parent_project() is mock_project

    def test_vector_store_config_parent_project_none(
        self, mock_vector_store_fts_config_properties
    ):
        """Test that parent project is None if not set."""
        config = VectorStoreConfig(
            name="test_store",
            store_type=VectorStoreType.LANCE_DB_FTS,
            properties=mock_vector_store_fts_config_properties,
        )

        assert config.parent_project() is None

    def test_project_has_vector_store_configs(
        self, mock_project, mock_vector_store_fts_config_properties
    ):
        """Test that a saved config is listed among the project's vector store configs."""
        config = VectorStoreConfig(
            name="test_store",
            store_type=VectorStoreType.LANCE_DB_FTS,
            properties=mock_vector_store_fts_config_properties,
            parent=mock_project,
        )
        config.save_to_file()

        stored_configs = mock_project.vector_store_configs(readonly=True)
        assert len(stored_configs) == 1
        assert config.id in [vc.id for vc in stored_configs]
@@ -26,6 +26,7 @@ class KilnBuiltInToolId(str, Enum):
26
26
 
27
27
 
28
28
# Tool ID prefixes for MCP-server tools (remote and local).
MCP_REMOTE_TOOL_ID_PREFIX = "mcp::remote::"
MCP_LOCAL_TOOL_ID_PREFIX = "mcp::local::"
# RAG tool IDs take the form kiln_tool::rag::<rag_config_id>.
RAG_TOOL_ID_PREFIX = "kiln_tool::rag::"
30
31
 
31
32
 
@@ -58,6 +59,15 @@ def _check_tool_id(id: str) -> str:
58
59
  )
59
60
  return id
60
61
 
62
+ # RAG tools must have format: kiln_tool::rag::<rag_config_id>
63
+ if id.startswith(RAG_TOOL_ID_PREFIX):
64
+ rag_config_id = rag_config_id_from_id(id)
65
+ if not rag_config_id:
66
+ raise ValueError(
67
+ f"Invalid RAG tool ID: {id}. Expected format: 'kiln_tool::rag::<rag_config_id>'."
68
+ )
69
+ return id
70
+
61
71
  raise ValueError(f"Invalid tool ID: {id}")
62
72
 
63
73
 
@@ -81,3 +91,15 @@ def mcp_server_and_tool_name_from_id(id: str) -> tuple[str, str]:
81
91
  f"Invalid MCP tool ID: {id}. Expected format: 'mcp::(remote|local)::<server_id>::<tool_name>'."
82
92
  )
83
93
  return parts[2], parts[3] # server_id, tool_name
94
+
95
+
96
def rag_config_id_from_id(id: str) -> str:
    """Extract the RAG config ID from a RAG tool ID.

    RAG tool IDs have the form ``kiln_tool::rag::<rag_config_id>``.

    Returns:
        Everything after the RAG prefix. For the bare prefix
        ``kiln_tool::rag::`` this is an empty string; callers that need a
        non-empty config ID must check for that themselves.

    Raises:
        ValueError: If ``id`` does not start with the RAG prefix, or the part
            after the prefix contains another ``::`` separator.
    """
    if id.startswith(RAG_TOOL_ID_PREFIX):
        config_id = id[len(RAG_TOOL_ID_PREFIX) :]
        # A further '::' would mean more than three '::'-separated segments.
        if "::" not in config_id:
            return config_id
    raise ValueError(
        f"Invalid RAG tool ID: {id}. Expected format: 'kiln_tool::rag::<rag_config_id>'."
    )
@@ -0,0 +1,141 @@
1
+ from enum import Enum
2
+ from typing import TYPE_CHECKING, Union
3
+
4
+ from pydantic import BaseModel, Field, model_validator
5
+
6
+ from kiln_ai.datamodel.basemodel import FilenameString, KilnParentedModel
7
+ from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
8
+ from kiln_ai.utils.validation import (
9
+ validate_return_dict_prop,
10
+ validate_return_dict_prop_optional,
11
+ )
12
+
13
+ if TYPE_CHECKING:
14
+ from kiln_ai.datamodel.project import Project
15
+
16
+
17
class VectorStoreType(str, Enum):
    """Supported vector store backends; each value is the stored string identifier."""

    # LanceDB full-text search; nprobes is not required.
    LANCE_DB_FTS = "lancedb_fts"
    # LanceDB hybrid search; nprobes is required.
    LANCE_DB_HYBRID = "lancedb_hybrid"
    # LanceDB vector search; nprobes is required.
    LANCE_DB_VECTOR = "lancedb_vector"
21
+
22
+
23
class LanceDBConfigBaseProperties(BaseModel):
    """Typed view of the LanceDB entries stored in ``VectorStoreConfig.properties``."""

    # Result count and search tuning.
    similarity_top_k: int = Field(
        description="The number of results to return from the vector store.",
    )
    overfetch_factor: int = Field(
        description="The overfetch factor to use for the vector search.",
    )

    # Column names inside the LanceDB table.
    vector_column_name: str = Field(
        description="The name of the vector column in the vector store.",
    )
    text_key: str = Field(
        description="The name of the text column in the vector store.",
    )
    doc_id_key: str = Field(
        description="The name of the document id column in the vector store.",
    )

    # Optional; VectorStoreConfig requires it only for vector and hybrid stores.
    nprobes: int | None = Field(
        default=None,
        description="The number of probes to use for the vector search.",
    )
43
+
44
+
45
+ class VectorStoreConfig(KilnParentedModel):
46
+ name: FilenameString = Field(
47
+ description="A name for your own reference to identify the vector store config.",
48
+ )
49
+ description: str | None = Field(
50
+ description="A description for your own reference.",
51
+ default=None,
52
+ )
53
+ store_type: VectorStoreType = Field(
54
+ description="The type of vector store to use.",
55
+ )
56
+ properties: dict[str, str | int | float | None] = Field(
57
+ description="The properties of the vector store config, specific to the selected store_type.",
58
+ )
59
+
60
+ @model_validator(mode="after")
61
+ def validate_properties(self):
62
+ match self.store_type:
63
+ case (
64
+ VectorStoreType.LANCE_DB_FTS
65
+ | VectorStoreType.LANCE_DB_HYBRID
66
+ | VectorStoreType.LANCE_DB_VECTOR
67
+ ):
68
+ return self.validate_lancedb_properties(self.store_type)
69
+ case _:
70
+ raise_exhaustive_enum_error(self.store_type)
71
+
72
+ def validate_lancedb_properties(self, store_type: VectorStoreType):
73
+ err_msg_prefix = f"LanceDB vector store configs properties for {store_type}:"
74
+ validate_return_dict_prop(
75
+ self.properties, "similarity_top_k", int, err_msg_prefix
76
+ )
77
+ validate_return_dict_prop(
78
+ self.properties, "overfetch_factor", int, err_msg_prefix
79
+ )
80
+ validate_return_dict_prop(
81
+ self.properties, "vector_column_name", str, err_msg_prefix
82
+ )
83
+ validate_return_dict_prop(self.properties, "text_key", str, err_msg_prefix)
84
+ validate_return_dict_prop(self.properties, "doc_id_key", str, err_msg_prefix)
85
+
86
+ # nprobes is only used for vector and hybrid queries
87
+ if (
88
+ store_type == VectorStoreType.LANCE_DB_VECTOR
89
+ or store_type == VectorStoreType.LANCE_DB_HYBRID
90
+ ):
91
+ validate_return_dict_prop(self.properties, "nprobes", int, err_msg_prefix)
92
+
93
+ return self
94
+
95
+ @property
96
+ def lancedb_properties(self) -> LanceDBConfigBaseProperties:
97
+ err_msg_prefix = "LanceDB vector store configs properties:"
98
+ return LanceDBConfigBaseProperties(
99
+ similarity_top_k=validate_return_dict_prop(
100
+ self.properties,
101
+ "similarity_top_k",
102
+ int,
103
+ err_msg_prefix,
104
+ ),
105
+ overfetch_factor=validate_return_dict_prop(
106
+ self.properties,
107
+ "overfetch_factor",
108
+ int,
109
+ err_msg_prefix,
110
+ ),
111
+ vector_column_name=validate_return_dict_prop(
112
+ self.properties,
113
+ "vector_column_name",
114
+ str,
115
+ err_msg_prefix,
116
+ ),
117
+ text_key=validate_return_dict_prop(
118
+ self.properties,
119
+ "text_key",
120
+ str,
121
+ err_msg_prefix,
122
+ ),
123
+ doc_id_key=validate_return_dict_prop(
124
+ self.properties,
125
+ "doc_id_key",
126
+ str,
127
+ err_msg_prefix,
128
+ ),
129
+ nprobes=validate_return_dict_prop_optional(
130
+ self.properties,
131
+ "nprobes",
132
+ int,
133
+ err_msg_prefix,
134
+ ),
135
+ )
136
+
137
+ # Workaround to return typed parent without importing Project
138
+ def parent_project(self) -> Union["Project", None]:
139
+ if self.parent is None or self.parent.__class__.__name__ != "Project":
140
+ return None
141
+ return self.parent # type: ignore
@@ -3,6 +3,7 @@ import os
3
3
  import subprocess
4
4
  import sys
5
5
  from contextlib import asynccontextmanager
6
+ from datetime import timedelta
6
7
  from typing import AsyncGenerator
7
8
 
8
9
  import httpx
@@ -171,7 +172,9 @@ class MCPSessionManager:
171
172
 
172
173
  try:
173
174
  async with stdio_client(server_params) as (read, write):
174
- async with ClientSession(read, write) as session:
175
+ async with ClientSession(
176
+ read, write, read_timeout_seconds=timedelta(seconds=8)
177
+ ) as session:
175
178
  await session.initialize()
176
179
  yield session
177
180
  except Exception as e: