kiln-ai 0.20.1__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic; see the registry's page for this release for more details (the original link was not preserved).

Files changed (133)
  1. kiln_ai/adapters/__init__.py +6 -0
  2. kiln_ai/adapters/adapter_registry.py +43 -226
  3. kiln_ai/adapters/chunkers/__init__.py +13 -0
  4. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  5. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  6. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  7. kiln_ai/adapters/chunkers/helpers.py +23 -0
  8. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  9. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  10. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  11. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  12. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  13. kiln_ai/adapters/embedding/__init__.py +0 -0
  14. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  15. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  16. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  17. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  18. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  19. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  20. kiln_ai/adapters/eval/eval_runner.py +6 -2
  21. kiln_ai/adapters/eval/test_base_eval.py +1 -3
  22. kiln_ai/adapters/eval/test_g_eval.py +1 -1
  23. kiln_ai/adapters/extractors/__init__.py +18 -0
  24. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  25. kiln_ai/adapters/extractors/encoding.py +20 -0
  26. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  27. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  28. kiln_ai/adapters/extractors/litellm_extractor.py +406 -0
  29. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  30. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  31. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  32. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  33. kiln_ai/adapters/extractors/test_litellm_extractor.py +1290 -0
  34. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  35. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
  36. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  37. kiln_ai/adapters/ml_embedding_model_list.py +494 -0
  38. kiln_ai/adapters/ml_model_list.py +876 -18
  39. kiln_ai/adapters/model_adapters/litellm_adapter.py +40 -75
  40. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +79 -1
  41. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +119 -5
  42. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +9 -3
  43. kiln_ai/adapters/model_adapters/test_structured_output.py +9 -10
  44. kiln_ai/adapters/ollama_tools.py +69 -12
  45. kiln_ai/adapters/provider_tools.py +190 -46
  46. kiln_ai/adapters/rag/deduplication.py +49 -0
  47. kiln_ai/adapters/rag/progress.py +252 -0
  48. kiln_ai/adapters/rag/rag_runners.py +844 -0
  49. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  50. kiln_ai/adapters/rag/test_progress.py +785 -0
  51. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  52. kiln_ai/adapters/remote_config.py +80 -8
  53. kiln_ai/adapters/test_adapter_registry.py +579 -86
  54. kiln_ai/adapters/test_ml_embedding_model_list.py +239 -0
  55. kiln_ai/adapters/test_ml_model_list.py +202 -0
  56. kiln_ai/adapters/test_ollama_tools.py +340 -1
  57. kiln_ai/adapters/test_prompt_builders.py +1 -1
  58. kiln_ai/adapters/test_provider_tools.py +199 -8
  59. kiln_ai/adapters/test_remote_config.py +551 -56
  60. kiln_ai/adapters/vector_store/__init__.py +1 -0
  61. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  62. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  63. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  64. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  65. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  66. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  67. kiln_ai/datamodel/__init__.py +16 -13
  68. kiln_ai/datamodel/basemodel.py +201 -4
  69. kiln_ai/datamodel/chunk.py +158 -0
  70. kiln_ai/datamodel/datamodel_enums.py +27 -0
  71. kiln_ai/datamodel/embedding.py +64 -0
  72. kiln_ai/datamodel/external_tool_server.py +206 -54
  73. kiln_ai/datamodel/extraction.py +317 -0
  74. kiln_ai/datamodel/project.py +33 -1
  75. kiln_ai/datamodel/rag.py +79 -0
  76. kiln_ai/datamodel/task.py +5 -0
  77. kiln_ai/datamodel/task_output.py +41 -11
  78. kiln_ai/datamodel/test_attachment.py +649 -0
  79. kiln_ai/datamodel/test_basemodel.py +270 -14
  80. kiln_ai/datamodel/test_chunk_models.py +317 -0
  81. kiln_ai/datamodel/test_dataset_split.py +1 -1
  82. kiln_ai/datamodel/test_datasource.py +50 -0
  83. kiln_ai/datamodel/test_embedding_models.py +448 -0
  84. kiln_ai/datamodel/test_eval_model.py +6 -6
  85. kiln_ai/datamodel/test_external_tool_server.py +534 -152
  86. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  87. kiln_ai/datamodel/test_extraction_model.py +501 -0
  88. kiln_ai/datamodel/test_rag.py +641 -0
  89. kiln_ai/datamodel/test_task.py +35 -1
  90. kiln_ai/datamodel/test_tool_id.py +187 -1
  91. kiln_ai/datamodel/test_vector_store.py +320 -0
  92. kiln_ai/datamodel/tool_id.py +58 -0
  93. kiln_ai/datamodel/vector_store.py +141 -0
  94. kiln_ai/tools/base_tool.py +12 -3
  95. kiln_ai/tools/built_in_tools/math_tools.py +12 -4
  96. kiln_ai/tools/kiln_task_tool.py +158 -0
  97. kiln_ai/tools/mcp_server_tool.py +2 -2
  98. kiln_ai/tools/mcp_session_manager.py +51 -22
  99. kiln_ai/tools/rag_tools.py +164 -0
  100. kiln_ai/tools/test_kiln_task_tool.py +527 -0
  101. kiln_ai/tools/test_mcp_server_tool.py +4 -15
  102. kiln_ai/tools/test_mcp_session_manager.py +187 -227
  103. kiln_ai/tools/test_rag_tools.py +929 -0
  104. kiln_ai/tools/test_tool_registry.py +290 -7
  105. kiln_ai/tools/tool_registry.py +69 -16
  106. kiln_ai/utils/__init__.py +3 -0
  107. kiln_ai/utils/async_job_runner.py +62 -17
  108. kiln_ai/utils/config.py +2 -2
  109. kiln_ai/utils/env.py +15 -0
  110. kiln_ai/utils/filesystem.py +14 -0
  111. kiln_ai/utils/filesystem_cache.py +60 -0
  112. kiln_ai/utils/litellm.py +94 -0
  113. kiln_ai/utils/lock.py +100 -0
  114. kiln_ai/utils/mime_type.py +38 -0
  115. kiln_ai/utils/open_ai_types.py +19 -2
  116. kiln_ai/utils/pdf_utils.py +59 -0
  117. kiln_ai/utils/test_async_job_runner.py +151 -35
  118. kiln_ai/utils/test_env.py +142 -0
  119. kiln_ai/utils/test_filesystem_cache.py +316 -0
  120. kiln_ai/utils/test_litellm.py +206 -0
  121. kiln_ai/utils/test_lock.py +185 -0
  122. kiln_ai/utils/test_mime_type.py +66 -0
  123. kiln_ai/utils/test_open_ai_types.py +88 -12
  124. kiln_ai/utils/test_pdf_utils.py +86 -0
  125. kiln_ai/utils/test_uuid.py +111 -0
  126. kiln_ai/utils/test_validation.py +524 -0
  127. kiln_ai/utils/uuid.py +9 -0
  128. kiln_ai/utils/validation.py +90 -0
  129. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/METADATA +9 -1
  130. kiln_ai-0.22.0.dist-info/RECORD +213 -0
  131. kiln_ai-0.20.1.dist-info/RECORD +0 -138
  132. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/WHEEL +0 -0
  133. {kiln_ai-0.20.1.dist-info → kiln_ai-0.22.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/tools/test_tool_registry.py CHANGED
@@ -1,15 +1,22 @@
1
+ from pathlib import Path
1
2
  from unittest.mock import Mock
2
3
 
3
4
  import pytest
4
5
 
5
- from kiln_ai.datamodel.external_tool_server import ExternalToolServer, ToolServerType
6
+ from kiln_ai.datamodel.external_tool_server import (
7
+ ExternalToolServer,
8
+ ToolServerType,
9
+ )
6
10
  from kiln_ai.datamodel.project import Project
7
11
  from kiln_ai.datamodel.task import Task
8
12
  from kiln_ai.datamodel.tool_id import (
13
+ KILN_TASK_TOOL_ID_PREFIX,
9
14
  MCP_LOCAL_TOOL_ID_PREFIX,
10
15
  MCP_REMOTE_TOOL_ID_PREFIX,
16
+ RAG_TOOL_ID_PREFIX,
11
17
  KilnBuiltInToolId,
12
18
  _check_tool_id,
19
+ kiln_task_server_id_from_tool_id,
13
20
  mcp_server_and_tool_name_from_id,
14
21
  )
15
22
  from kiln_ai.tools.built_in_tools.math_tools import (
@@ -18,6 +25,7 @@ from kiln_ai.tools.built_in_tools.math_tools import (
18
25
  MultiplyTool,
19
26
  SubtractTool,
20
27
  )
28
+ from kiln_ai.tools.kiln_task_tool import KilnTaskTool
21
29
  from kiln_ai.tools.mcp_server_tool import MCPServerTool
22
30
  from kiln_ai.tools.tool_registry import tool_from_id
23
31
 
@@ -85,7 +93,6 @@ class TestToolRegistry:
85
93
  type=ToolServerType.remote_mcp,
86
94
  properties={
87
95
  "server_url": "https://example.com",
88
- "headers": {},
89
96
  },
90
97
  )
91
98
 
@@ -143,7 +150,7 @@ class TestToolRegistry:
143
150
  tool_id = f"{MCP_LOCAL_TOOL_ID_PREFIX}test_server::test_tool"
144
151
  with pytest.raises(
145
152
  ValueError,
146
- match="Unable to resolve tool from id.*Requires a parent project/task",
153
+ match=r"Unable to resolve tool from id.*Requires a parent project/task",
147
154
  ):
148
155
  tool_from_id(tool_id, task=None)
149
156
 
@@ -155,7 +162,6 @@ class TestToolRegistry:
155
162
  type=ToolServerType.remote_mcp,
156
163
  properties={
157
164
  "server_url": "https://example.com",
158
- "headers": {},
159
165
  },
160
166
  )
161
167
 
@@ -181,6 +187,93 @@ class TestToolRegistry:
181
187
  ):
182
188
  tool_from_id(tool_id, task=mock_task)
183
189
 
190
+ def test_tool_from_id_rag_tool_success(self):
191
+ """Test that tool_from_id works with RAG tool IDs."""
192
+ # Create mock RAG config
193
+ from unittest.mock import patch
194
+
195
+ with (
196
+ patch("kiln_ai.tools.tool_registry.RagConfig") as mock_rag_config_class,
197
+ patch("kiln_ai.tools.rag_tools.RagTool") as mock_rag_tool_class,
198
+ ):
199
+ # Setup mock RAG config
200
+ mock_rag_config = Mock()
201
+ mock_rag_config.id = "test_rag_config"
202
+ mock_rag_config_class.from_id_and_parent_path.return_value = mock_rag_config
203
+
204
+ # Setup mock RAG tool
205
+ mock_rag_tool = Mock()
206
+ mock_rag_tool_class.return_value = mock_rag_tool
207
+
208
+ # Create mock project
209
+ mock_project = Mock(spec=Project)
210
+ mock_project.id = "test_project_id"
211
+ mock_project.path = Path("/test/path")
212
+
213
+ # Create mock task with parent project
214
+ mock_task = Mock(spec=Task)
215
+ mock_task.parent_project.return_value = mock_project
216
+
217
+ # Test with RAG tool ID
218
+ tool_id = f"{RAG_TOOL_ID_PREFIX}test_rag_config"
219
+ tool = tool_from_id(tool_id, task=mock_task)
220
+
221
+ # Verify the tool is RagTool
222
+ assert tool == mock_rag_tool
223
+ mock_rag_config_class.from_id_and_parent_path.assert_called_once_with(
224
+ "test_rag_config", Path("/test/path")
225
+ )
226
+ mock_rag_tool_class.assert_called_once_with(tool_id, mock_rag_config)
227
+
228
+ def test_tool_from_id_rag_tool_no_task(self):
229
+ """Test that RAG tool ID without task raises ValueError."""
230
+ tool_id = f"{RAG_TOOL_ID_PREFIX}test_rag_config"
231
+
232
+ with pytest.raises(
233
+ ValueError,
234
+ match=r"Unable to resolve tool from id.*Requires a parent project/task",
235
+ ):
236
+ tool_from_id(tool_id, task=None)
237
+
238
+ def test_tool_from_id_rag_tool_no_project(self):
239
+ """Test that RAG tool ID with task but no project raises ValueError."""
240
+ # Create mock task without parent project
241
+ mock_task = Mock(spec=Task)
242
+ mock_task.parent_project.return_value = None
243
+
244
+ tool_id = f"{RAG_TOOL_ID_PREFIX}test_rag_config"
245
+
246
+ with pytest.raises(
247
+ ValueError,
248
+ match=r"Unable to resolve tool from id.*Requires a parent project/task",
249
+ ):
250
+ tool_from_id(tool_id, task=mock_task)
251
+
252
+ def test_tool_from_id_rag_config_not_found(self):
253
+ """Test that RAG tool ID with missing RAG config raises ValueError."""
254
+ from unittest.mock import patch
255
+
256
+ with patch("kiln_ai.tools.tool_registry.RagConfig") as mock_rag_config_class:
257
+ # Setup mock to return None (config not found)
258
+ mock_rag_config_class.from_id_and_parent_path.return_value = None
259
+
260
+ # Create mock project
261
+ mock_project = Mock(spec=Project)
262
+ mock_project.id = "test_project_id"
263
+ mock_project.path = Path("/test/path")
264
+
265
+ # Create mock task with parent project
266
+ mock_task = Mock(spec=Task)
267
+ mock_task.parent_project.return_value = mock_project
268
+
269
+ tool_id = f"{RAG_TOOL_ID_PREFIX}missing_rag_config"
270
+
271
+ with pytest.raises(
272
+ ValueError,
273
+ match="RAG config not found: missing_rag_config in project test_project_id for tool",
274
+ ):
275
+ tool_from_id(tool_id, task=mock_task)
276
+
184
277
  def test_all_built_in_tools_are_registered(self):
185
278
  """Test that all KilnBuiltInToolId enum members are handled by the registry."""
186
279
  for tool_id in KilnBuiltInToolId:
@@ -272,6 +365,48 @@ class TestToolRegistry:
272
365
  with pytest.raises(ValueError, match=f"Invalid tool ID: {invalid_id}"):
273
366
  _check_tool_id(invalid_id)
274
367
 
368
+ def test_check_tool_id_valid_kiln_task_tool_id(self):
369
+ """Test that _check_tool_id accepts valid Kiln task tool IDs."""
370
+ valid_kiln_task_ids = [
371
+ f"{KILN_TASK_TOOL_ID_PREFIX}server123",
372
+ f"{KILN_TASK_TOOL_ID_PREFIX}my_task_server",
373
+ f"{KILN_TASK_TOOL_ID_PREFIX}123456789",
374
+ f"{KILN_TASK_TOOL_ID_PREFIX}server_with_underscores",
375
+ f"{KILN_TASK_TOOL_ID_PREFIX}server-with-dashes",
376
+ ]
377
+
378
+ for tool_id in valid_kiln_task_ids:
379
+ result = _check_tool_id(tool_id)
380
+ assert result == tool_id
381
+
382
+ def test_check_tool_id_invalid_kiln_task_tool_id(self):
383
+ """Test that _check_tool_id rejects invalid Kiln task tool IDs."""
384
+ # These start with the prefix but have wrong format
385
+ invalid_kiln_task_format_ids = [
386
+ f"{KILN_TASK_TOOL_ID_PREFIX}", # Missing server ID
387
+ f"{KILN_TASK_TOOL_ID_PREFIX}::", # Empty server ID
388
+ f"{KILN_TASK_TOOL_ID_PREFIX}server::tool", # Too many parts (3 instead of 2)
389
+ f"{KILN_TASK_TOOL_ID_PREFIX}server::tool::extra", # Too many parts (4 instead of 2)
390
+ ]
391
+
392
+ for invalid_id in invalid_kiln_task_format_ids:
393
+ with pytest.raises(
394
+ ValueError, match=f"Invalid Kiln task tool ID format: {invalid_id}"
395
+ ):
396
+ _check_tool_id(invalid_id)
397
+
398
+ # These don't match the prefix - get generic error
399
+ invalid_generic_ids = [
400
+ "kiln_task:", # Missing last colon (doesn't match full prefix)
401
+ "kiln:task::server", # Wrong prefix format
402
+ "kiln_task_server", # Missing colons
403
+ "task::server", # Missing kiln prefix
404
+ ]
405
+
406
+ for invalid_id in invalid_generic_ids:
407
+ with pytest.raises(ValueError, match=f"Invalid tool ID: {invalid_id}"):
408
+ _check_tool_id(invalid_id)
409
+
275
410
  def test_mcp_server_and_tool_name_from_id_valid_inputs(self):
276
411
  """Test that mcp_server_and_tool_name_from_id correctly parses valid MCP tool IDs."""
277
412
  test_cases = [
@@ -400,13 +535,71 @@ class TestToolRegistry:
400
535
  assert server_id == expected_server
401
536
  assert tool_name == expected_tool
402
537
 
538
+ def test_kiln_task_server_id_from_tool_id_valid_inputs(self):
539
+ """Test that kiln_task_server_id_from_tool_id correctly parses valid Kiln task tool IDs."""
540
+ test_cases = [
541
+ ("kiln_task::server123", "server123"),
542
+ ("kiln_task::my_task_server", "my_task_server"),
543
+ ("kiln_task::123456789", "123456789"),
544
+ ("kiln_task::server_with_underscores", "server_with_underscores"),
545
+ ("kiln_task::server-with-dashes", "server-with-dashes"),
546
+ ("kiln_task::a", "a"), # Minimal valid case
547
+ (
548
+ "kiln_task::very_long_server_name_with_numbers_123",
549
+ "very_long_server_name_with_numbers_123",
550
+ ),
551
+ ]
552
+
553
+ for tool_id, expected_server_id in test_cases:
554
+ result = kiln_task_server_id_from_tool_id(tool_id)
555
+ assert result == expected_server_id, (
556
+ f"Failed for {tool_id}: expected {expected_server_id}, got {result}"
557
+ )
558
+
559
+ def test_kiln_task_server_id_from_tool_id_invalid_inputs(self):
560
+ """Test that kiln_task_server_id_from_tool_id raises ValueError for invalid Kiln task tool IDs."""
561
+ invalid_inputs = [
562
+ "kiln_task::", # Empty server ID
563
+ "kiln_task::server::tool", # Too many parts (3 instead of 2)
564
+ "kiln_task::server::tool::extra", # Too many parts (4 instead of 2)
565
+ "invalid::format", # Wrong prefix
566
+ "", # Empty string
567
+ "single_part", # No separators
568
+ "two::parts", # Only 2 parts but wrong prefix
569
+ "kiln_task", # Missing colons
570
+ ]
571
+
572
+ for invalid_id in invalid_inputs:
573
+ with pytest.raises(
574
+ ValueError,
575
+ match=r"Invalid Kiln task tool ID format:.*Expected format.*kiln_task::<server_id>",
576
+ ):
577
+ kiln_task_server_id_from_tool_id(invalid_id)
578
+
579
+ @pytest.mark.parametrize(
580
+ "tool_id,expected_server_id",
581
+ [
582
+ ("kiln_task::test_server", "test_server"),
583
+ ("kiln_task::s", "s"),
584
+ ("kiln_task::long_server_name_123", "long_server_name_123"),
585
+ ("kiln_task::server-with-dashes", "server-with-dashes"),
586
+ ("kiln_task::server_with_underscores", "server_with_underscores"),
587
+ ],
588
+ )
589
+ def test_kiln_task_server_id_from_tool_id_parametrized(
590
+ self, tool_id, expected_server_id
591
+ ):
592
+ """Parametrized test for kiln_task_server_id_from_tool_id with various valid inputs."""
593
+ server_id = kiln_task_server_id_from_tool_id(tool_id)
594
+ assert server_id == expected_server_id
595
+
403
596
  def test_tool_from_id_mcp_missing_task_raises_error(self):
404
597
  """Test that MCP tool ID with missing task raises ValueError."""
405
598
  mcp_tool_id = f"{MCP_REMOTE_TOOL_ID_PREFIX}test_server::test_tool"
406
599
 
407
600
  with pytest.raises(
408
601
  ValueError,
409
- match="Unable to resolve tool from id.*Requires a parent project/task",
602
+ match=r"Unable to resolve tool from id.*Requires a parent project/task",
410
603
  ):
411
604
  tool_from_id(mcp_tool_id, task=None)
412
605
 
@@ -419,7 +612,6 @@ class TestToolRegistry:
419
612
  description="Test MCP server",
420
613
  properties={
421
614
  "server_url": "https://example.com",
422
- "headers": {},
423
615
  },
424
616
  )
425
617
 
@@ -450,7 +642,6 @@ class TestToolRegistry:
450
642
  description="Different MCP server",
451
643
  properties={
452
644
  "server_url": "https://example.com",
453
- "headers": {},
454
645
  },
455
646
  )
456
647
 
@@ -471,3 +662,95 @@ class TestToolRegistry:
471
662
  match="External tool server not found: nonexistent_server in project ID test_project_id",
472
663
  ):
473
664
  tool_from_id(mcp_tool_id, task=mock_task)
665
+
666
+ def test_tool_from_id_kiln_task_tool_success(self):
667
+ """Test that tool_from_id works with Kiln task tool IDs."""
668
+ # Create mock external tool server for Kiln task
669
+ mock_server = ExternalToolServer(
670
+ name="test_kiln_task_server",
671
+ type=ToolServerType.kiln_task,
672
+ description="Test Kiln task server",
673
+ properties={
674
+ "name": "test_task_tool",
675
+ "description": "A test task tool",
676
+ "task_id": "test_task_123",
677
+ "run_config_id": "test_config_456",
678
+ "is_archived": False,
679
+ },
680
+ )
681
+
682
+ # Create mock project with the external tool server
683
+ mock_project = Mock(spec=Project)
684
+ mock_project.id = "test_project_id"
685
+ mock_project.external_tool_servers.return_value = [mock_server]
686
+
687
+ # Create mock task with parent project
688
+ mock_task = Mock(spec=Task)
689
+ mock_task.parent_project.return_value = mock_project
690
+
691
+ # Test with Kiln task tool ID
692
+ tool_id = f"{KILN_TASK_TOOL_ID_PREFIX}{mock_server.id}"
693
+ tool = tool_from_id(tool_id, task=mock_task)
694
+
695
+ # Verify the tool is KilnTaskTool
696
+ assert isinstance(tool, KilnTaskTool)
697
+ assert tool._project_id == "test_project_id"
698
+ assert tool._tool_id == tool_id
699
+ assert tool._tool_server_model == mock_server
700
+
701
+ def test_tool_from_id_kiln_task_tool_no_task(self):
702
+ """Test that Kiln task tool ID without task raises ValueError."""
703
+ tool_id = f"{KILN_TASK_TOOL_ID_PREFIX}test_server"
704
+ with pytest.raises(
705
+ ValueError,
706
+ match=r"Unable to resolve tool from id.*Requires a parent project/task",
707
+ ):
708
+ tool_from_id(tool_id, task=None)
709
+
710
+ def test_tool_from_id_kiln_task_tool_no_project(self):
711
+ """Test that Kiln task tool ID with task but no project raises ValueError."""
712
+ # Create mock task without parent project
713
+ mock_task = Mock(spec=Task)
714
+ mock_task.parent_project.return_value = None
715
+
716
+ tool_id = f"{KILN_TASK_TOOL_ID_PREFIX}test_server"
717
+
718
+ with pytest.raises(
719
+ ValueError,
720
+ match=r"Unable to resolve tool from id.*Requires a parent project/task",
721
+ ):
722
+ tool_from_id(tool_id, task=mock_task)
723
+
724
+ def test_tool_from_id_kiln_task_tool_server_not_found(self):
725
+ """Test that Kiln task tool ID with server not found raises ValueError."""
726
+ # Create mock external tool server with different ID
727
+ mock_server = ExternalToolServer(
728
+ name="different_server",
729
+ type=ToolServerType.kiln_task,
730
+ description="Different Kiln task server",
731
+ properties={
732
+ "name": "different_tool",
733
+ "description": "A different task tool",
734
+ "task_id": "different_task_123",
735
+ "run_config_id": "different_config_456",
736
+ "is_archived": False,
737
+ },
738
+ )
739
+
740
+ # Create mock project with the external tool server
741
+ mock_project = Mock(spec=Project)
742
+ mock_project.id = "test_project_id"
743
+ mock_project.external_tool_servers.return_value = [mock_server]
744
+
745
+ # Create mock task with parent project
746
+ mock_task = Mock(spec=Task)
747
+ mock_task.parent_project.return_value = mock_project
748
+
749
+ # Use a tool ID with a server that doesn't exist in the project
750
+ tool_id = f"{KILN_TASK_TOOL_ID_PREFIX}nonexistent_server"
751
+
752
+ with pytest.raises(
753
+ ValueError,
754
+ match="Kiln Task External tool server not found: nonexistent_server in project ID test_project_id",
755
+ ):
756
+ tool_from_id(tool_id, task=mock_task)
kiln_ai/tools/tool_registry.py CHANGED
@@ -1,9 +1,14 @@
1
+ from kiln_ai.datamodel.rag import RagConfig
1
2
  from kiln_ai.datamodel.task import Task
2
3
  from kiln_ai.datamodel.tool_id import (
4
+ KILN_TASK_TOOL_ID_PREFIX,
3
5
  MCP_LOCAL_TOOL_ID_PREFIX,
4
6
  MCP_REMOTE_TOOL_ID_PREFIX,
7
+ RAG_TOOL_ID_PREFIX,
5
8
  KilnBuiltInToolId,
9
+ kiln_task_server_id_from_tool_id,
6
10
  mcp_server_and_tool_name_from_id,
11
+ rag_config_id_from_id,
7
12
  )
8
13
  from kiln_ai.tools.base_tool import KilnToolInterface
9
14
  from kiln_ai.tools.built_in_tools.math_tools import (
@@ -12,6 +17,7 @@ from kiln_ai.tools.built_in_tools.math_tools import (
12
17
  MultiplyTool,
13
18
  SubtractTool,
14
19
  )
20
+ from kiln_ai.tools.kiln_task_tool import KilnTaskTool
15
21
  from kiln_ai.tools.mcp_server_tool import MCPServerTool
16
22
  from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
17
23
 
@@ -35,30 +41,77 @@ def tool_from_id(tool_id: str, task: Task | None = None) -> KilnToolInterface:
35
41
  case _:
36
42
  raise_exhaustive_enum_error(typed_tool_id)
37
43
 
38
- # Check MCP Server Tools
39
- if tool_id.startswith((MCP_REMOTE_TOOL_ID_PREFIX, MCP_LOCAL_TOOL_ID_PREFIX)):
44
+ # Check if this looks like an MCP or Kiln Task tool ID that requires a project
45
+ is_mcp_tool = tool_id.startswith(
46
+ (MCP_REMOTE_TOOL_ID_PREFIX, MCP_LOCAL_TOOL_ID_PREFIX)
47
+ )
48
+ is_kiln_task_tool = tool_id.startswith(KILN_TASK_TOOL_ID_PREFIX)
49
+
50
+ if is_mcp_tool or is_kiln_task_tool:
40
51
  project = task.parent_project() if task is not None else None
41
- if project is None:
52
+ if project is None or project.id is None:
42
53
  raise ValueError(
43
54
  f"Unable to resolve tool from id: {tool_id}. Requires a parent project/task."
44
55
  )
45
56
 
46
- # Get the tool server ID and tool name from the ID
47
- tool_server_id, tool_name = mcp_server_and_tool_name_from_id(tool_id)
57
+ # Check MCP Server Tools
58
+ if is_mcp_tool:
59
+ # Get the tool server ID and tool name from the ID
60
+ tool_server_id, tool_name = mcp_server_and_tool_name_from_id(
61
+ tool_id
62
+ ) # Fixed function name
63
+
64
+ server = next(
65
+ (
66
+ server
67
+ for server in project.external_tool_servers()
68
+ if server.id == tool_server_id
69
+ ),
70
+ None,
71
+ )
72
+ if server is None:
73
+ raise ValueError(
74
+ f"External tool server not found: {tool_server_id} in project ID {project.id}"
75
+ )
76
+
77
+ return MCPServerTool(server, tool_name)
78
+
79
+ # Check Kiln Task Tools
80
+ if is_kiln_task_tool:
81
+ server_id = kiln_task_server_id_from_tool_id(tool_id)
82
+
83
+ server = next(
84
+ (
85
+ server
86
+ for server in project.external_tool_servers()
87
+ if server.id == server_id
88
+ ),
89
+ None,
90
+ )
91
+ if server is None:
92
+ raise ValueError(
93
+ f"Kiln Task External tool server not found: {server_id} in project ID {project.id}"
94
+ )
95
+
96
+ return KilnTaskTool(project.id, tool_id, server)
97
+
98
+ elif tool_id.startswith(RAG_TOOL_ID_PREFIX):
99
+ project = task.parent_project() if task is not None else None
100
+ if project is None:
101
+ raise ValueError(
102
+ f"Unable to resolve tool from id: {tool_id}. Requires a parent project/task."
103
+ )
48
104
 
49
- server = next(
50
- (
51
- server
52
- for server in project.external_tool_servers()
53
- if server.id == tool_server_id
54
- ),
55
- None,
56
- )
57
- if server is None:
105
+ rag_config_id = rag_config_id_from_id(tool_id)
106
+ rag_config = RagConfig.from_id_and_parent_path(rag_config_id, project.path)
107
+ if rag_config is None:
58
108
  raise ValueError(
59
- f"External tool server not found: {tool_server_id} in project ID {project.id}"
109
+ f"RAG config not found: {rag_config_id} in project {project.id} for tool {tool_id}"
60
110
  )
61
111
 
62
- return MCPServerTool(server, tool_name)
112
+ # Lazy import to avoid circular dependency
113
+ from kiln_ai.tools.rag_tools import RagTool
114
+
115
+ return RagTool(tool_id, rag_config)
63
116
 
64
117
  raise ValueError(f"Tool ID {tool_id} not found in tool registry")
kiln_ai/utils/__init__.py CHANGED
@@ -5,8 +5,11 @@ Misc utilities used in the kiln_ai library.
5
5
  """
6
6
 
7
7
  from . import config, formatting
8
+ from .lock import AsyncLockManager, shared_async_lock_manager
8
9
 
9
10
  __all__ = [
11
+ "AsyncLockManager",
10
12
  "config",
11
13
  "formatting",
14
+ "shared_async_lock_manager",
12
15
  ]
kiln_ai/utils/async_job_runner.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import asyncio
2
2
  import logging
3
3
  from dataclasses import dataclass
4
- from typing import AsyncGenerator, Awaitable, Callable, List, TypeVar
4
+ from typing import AsyncGenerator, Awaitable, Callable, Generic, List, TypeVar
5
5
 
6
6
  logger = logging.getLogger(__name__)
7
7
 
@@ -15,29 +15,66 @@ class Progress:
15
15
  errors: int
16
16
 
17
17
 
18
- class AsyncJobRunner:
19
- def __init__(self, concurrency: int = 1):
18
+ class AsyncJobRunnerObserver(Generic[T]):
19
+ async def on_error(self, job: T, error: Exception):
20
+ """
21
+ Called when a job raises an unhandled exception.
22
+ """
23
+ pass
24
+
25
+ async def on_success(self, job: T):
26
+ """
27
+ Called when a job completes successfully.
28
+ """
29
+ pass
30
+
31
+ async def on_job_start(self, job: T):
32
+ """
33
+ Called when a job starts.
34
+ """
35
+ pass
36
+
37
+
38
+ class AsyncJobRunner(Generic[T]):
39
+ def __init__(
40
+ self,
41
+ jobs: List[T],
42
+ run_job_fn: Callable[[T], Awaitable[bool]],
43
+ concurrency: int = 1,
44
+ observers: List[AsyncJobRunnerObserver[T]] | None = None,
45
+ ):
20
46
  if concurrency < 1:
21
47
  raise ValueError("concurrency must be ≥ 1")
22
48
  self.concurrency = concurrency
49
+ self.jobs = jobs
50
+ self.run_job_fn = run_job_fn
51
+ self.observers = observers or []
23
52
 
24
- async def run(
25
- self,
26
- jobs: List[T],
27
- run_job: Callable[[T], Awaitable[bool]],
28
- ) -> AsyncGenerator[Progress, None]:
53
+ async def notify_error(self, job: T, error: Exception):
54
+ for observer in self.observers:
55
+ await observer.on_error(job, error)
56
+
57
+ async def notify_success(self, job: T):
58
+ for observer in self.observers:
59
+ await observer.on_success(job)
60
+
61
+ async def notify_job_start(self, job: T):
62
+ for observer in self.observers:
63
+ await observer.on_job_start(job)
64
+
65
+ async def run(self) -> AsyncGenerator[Progress, None]:
29
66
  """
30
67
  Runs the jobs with parallel workers and yields progress updates.
31
68
  """
32
69
  complete = 0
33
70
  errors = 0
34
- total = len(jobs)
71
+ total = len(self.jobs)
35
72
 
36
73
  # Send initial status
37
74
  yield Progress(complete=complete, total=total, errors=errors)
38
75
 
39
76
  worker_queue: asyncio.Queue[T] = asyncio.Queue()
40
- for job in jobs:
77
+ for job in self.jobs:
41
78
  worker_queue.put_nowait(job)
42
79
 
43
80
  # simple status queue to return progress. True=success, False=error
@@ -46,7 +83,7 @@ class AsyncJobRunner:
46
83
  workers = []
47
84
  for _ in range(self.concurrency):
48
85
  task = asyncio.create_task(
49
- self._run_worker(worker_queue, status_queue, run_job),
86
+ self._run_worker(worker_queue, status_queue, self.run_job_fn),
50
87
  )
51
88
  workers.append(task)
52
89
 
@@ -64,7 +101,11 @@ class AsyncJobRunner:
64
101
  else:
65
102
  errors += 1
66
103
 
67
- yield Progress(complete=complete, total=total, errors=errors)
104
+ yield Progress(
105
+ complete=complete,
106
+ total=total,
107
+ errors=errors,
108
+ )
68
109
  except asyncio.TimeoutError:
69
110
  # Timeout is expected, just continue to recheck worker status
70
111
  # Don't love this but beats sentinels for reliability
@@ -82,7 +123,7 @@ class AsyncJobRunner:
82
123
  self,
83
124
  worker_queue: asyncio.Queue[T],
84
125
  status_queue: asyncio.Queue[bool],
85
- run_job: Callable[[T], Awaitable[bool]],
126
+ run_job_fn: Callable[[T], Awaitable[bool]],
86
127
  ):
87
128
  while True:
88
129
  try:
@@ -92,13 +133,17 @@ class AsyncJobRunner:
92
133
  break
93
134
 
94
135
  try:
95
- success = await run_job(job)
96
- except Exception:
136
+ await self.notify_job_start(job)
137
+ result = await run_job_fn(job)
138
+ if result:
139
+ await self.notify_success(job)
140
+ except Exception as e:
97
141
  logger.error("Job failed to complete", exc_info=True)
98
- success = False
142
+ await self.notify_error(job, e)
143
+ result = False
99
144
 
100
145
  try:
101
- await status_queue.put(success)
146
+ await status_queue.put(result)
102
147
  except Exception:
103
148
  logger.error("Failed to enqueue status for job", exc_info=True)
104
149
  finally:
kiln_ai/utils/config.py CHANGED
@@ -221,14 +221,14 @@ class Config:
221
221
  raise AttributeError(f"Config has no attribute '{name}'")
222
222
 
223
223
  @classmethod
224
- def settings_dir(cls, create=True):
224
+ def settings_dir(cls, create=True) -> str:
225
225
  settings_dir = os.path.join(Path.home(), ".kiln_ai")
226
226
  if create and not os.path.exists(settings_dir):
227
227
  os.makedirs(settings_dir)
228
228
  return settings_dir
229
229
 
230
230
  @classmethod
231
- def settings_path(cls, create=True):
231
+ def settings_path(cls, create=True) -> str:
232
232
  settings_dir = cls.settings_dir(create)
233
233
  return os.path.join(settings_dir, "settings.yaml")
234
234
 
kiln_ai/utils/env.py ADDED
@@ -0,0 +1,15 @@
1
+ import os
2
+ from contextlib import contextmanager
3
+
4
+
5
+ @contextmanager
6
+ def temporary_env(var_name: str, value: str):
7
+ old_value = os.environ.get(var_name)
8
+ os.environ[var_name] = value
9
+ try:
10
+ yield
11
+ finally:
12
+ if old_value is None:
13
+ os.environ.pop(var_name, None) # remove if it did not exist before
14
+ else:
15
+ os.environ[var_name] = old_value