kiln-ai 0.6.1__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (44)
  1. kiln_ai/adapters/__init__.py +2 -0
  2. kiln_ai/adapters/adapter_registry.py +19 -0
  3. kiln_ai/adapters/data_gen/test_data_gen_task.py +29 -21
  4. kiln_ai/adapters/fine_tune/__init__.py +14 -0
  5. kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
  6. kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
  7. kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
  8. kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
  9. kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
  10. kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
  11. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
  12. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +455 -0
  13. kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
  14. kiln_ai/adapters/langchain_adapters.py +103 -13
  15. kiln_ai/adapters/ml_model_list.py +239 -303
  16. kiln_ai/adapters/ollama_tools.py +115 -0
  17. kiln_ai/adapters/provider_tools.py +308 -0
  18. kiln_ai/adapters/repair/repair_task.py +4 -2
  19. kiln_ai/adapters/repair/test_repair_task.py +6 -11
  20. kiln_ai/adapters/test_langchain_adapter.py +229 -18
  21. kiln_ai/adapters/test_ollama_tools.py +42 -0
  22. kiln_ai/adapters/test_prompt_adaptors.py +7 -5
  23. kiln_ai/adapters/test_provider_tools.py +531 -0
  24. kiln_ai/adapters/test_structured_output.py +22 -43
  25. kiln_ai/datamodel/__init__.py +287 -24
  26. kiln_ai/datamodel/basemodel.py +122 -38
  27. kiln_ai/datamodel/model_cache.py +116 -0
  28. kiln_ai/datamodel/registry.py +31 -0
  29. kiln_ai/datamodel/test_basemodel.py +167 -4
  30. kiln_ai/datamodel/test_dataset_split.py +234 -0
  31. kiln_ai/datamodel/test_example_models.py +12 -0
  32. kiln_ai/datamodel/test_model_cache.py +244 -0
  33. kiln_ai/datamodel/test_models.py +215 -1
  34. kiln_ai/datamodel/test_registry.py +96 -0
  35. kiln_ai/utils/config.py +14 -1
  36. kiln_ai/utils/name_generator.py +125 -0
  37. kiln_ai/utils/test_name_geneator.py +47 -0
  38. kiln_ai-0.7.1.dist-info/METADATA +237 -0
  39. kiln_ai-0.7.1.dist-info/RECORD +58 -0
  40. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/WHEEL +1 -1
  41. kiln_ai/adapters/test_ml_model_list.py +0 -181
  42. kiln_ai-0.6.1.dist-info/METADATA +0 -88
  43. kiln_ai-0.6.1.dist-info/RECORD +0 -37
  44. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0

kiln_ai/adapters/test_langchain_adapter.py
@@ -1,18 +1,27 @@
+import os
 from unittest.mock import AsyncMock, MagicMock, patch

+import pytest
+from langchain_aws import ChatBedrockConverse
 from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+from langchain_fireworks import ChatFireworks
 from langchain_groq import ChatGroq
+from langchain_ollama import ChatOllama
+from langchain_openai import ChatOpenAI

-from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
+from kiln_ai.adapters.langchain_adapters import (
+    LangchainAdapter,
+    get_structured_output_options,
+    langchain_model_from_provider,
+)
+from kiln_ai.adapters.ml_model_list import KilnModelProvider, ModelProviderName
 from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder
 from kiln_ai.adapters.test_prompt_adaptors import build_test_task


 def test_langchain_adapter_munge_response(tmp_path):
     task = build_test_task(tmp_path)
-    lca = LangChainPromptAdapter(
-        kiln_task=task, model_name="llama_3_1_8b", provider="ollama"
-    )
+    lca = LangchainAdapter(kiln_task=task, model_name="llama_3_1_8b", provider="ollama")
     # Mistral Large tool calling format is a bit different
     response = {
         "name": "task_response",
@@ -35,7 +44,7 @@ def test_langchain_adapter_infer_model_name(tmp_path):
     task = build_test_task(tmp_path)
     custom = ChatGroq(model="llama-3.1-8b-instant", groq_api_key="test")

-    lca = LangChainPromptAdapter(kiln_task=task, custom_model=custom)
+    lca = LangchainAdapter(kiln_task=task, custom_model=custom)

     model_info = lca.adapter_info()
     assert model_info.model_name == "custom.langchain:llama-3.1-8b-instant"
@@ -45,9 +54,7 @@
 def test_langchain_adapter_info(tmp_path):
     task = build_test_task(tmp_path)

-    lca = LangChainPromptAdapter(
-        kiln_task=task, model_name="llama_3_1_8b", provider="ollama"
-    )
+    lca = LangchainAdapter(kiln_task=task, model_name="llama_3_1_8b", provider="ollama")

     model_info = lca.adapter_info()
     assert model_info.adapter_name == "kiln_langchain_adapter"
@@ -60,7 +67,7 @@ async def test_langchain_adapter_with_cot(tmp_path):
     task.output_json_schema = (
         '{"type": "object", "properties": {"count": {"type": "integer"}}}'
     )
-    lca = LangChainPromptAdapter(
+    lca = LangchainAdapter(
         kiln_task=task,
         model_name="llama_3_1_8b",
         provider="ollama",
@@ -69,13 +76,13 @@ async def test_langchain_adapter_with_cot(tmp_path):

     # Mock the base model and its invoke method
     mock_base_model = MagicMock()
-    mock_base_model.invoke.return_value = AIMessage(
-        content="Chain of thought reasoning..."
+    mock_base_model.ainvoke = AsyncMock(
+        return_value=AIMessage(content="Chain of thought reasoning...")
     )

     # Create a separate mock for self.model()
     mock_model_instance = MagicMock()
-    mock_model_instance.invoke.return_value = {"parsed": {"count": 1}}
+    mock_model_instance.ainvoke = AsyncMock(return_value={"parsed": {"count": 1}})

     # Mock the langchain_model_from function to return the base model
     mock_model_from = AsyncMock(return_value=mock_base_model)
@@ -85,14 +92,14 @@
         patch(
             "kiln_ai.adapters.langchain_adapters.langchain_model_from", mock_model_from
         ),
-        patch.object(LangChainPromptAdapter, "model", return_value=mock_model_instance),
+        patch.object(LangchainAdapter, "model", return_value=mock_model_instance),
     ):
         response = await lca._run("test input")

     # First 3 messages are the same for both calls
     for invoke_args in [
-        mock_base_model.invoke.call_args[0][0],
-        mock_model_instance.invoke.call_args[0][0],
+        mock_base_model.ainvoke.call_args[0][0],
+        mock_model_instance.ainvoke.call_args[0][0],
     ]:
         assert isinstance(
             invoke_args[0], SystemMessage
@@ -107,11 +114,11 @@ async def test_langchain_adapter_with_cot(tmp_path):
         assert "step by step" in invoke_args[2].content

     # the COT should only have 3 messages
-    assert len(mock_base_model.invoke.call_args[0][0]) == 3
-    assert len(mock_model_instance.invoke.call_args[0][0]) == 5
+    assert len(mock_base_model.ainvoke.call_args[0][0]) == 3
+    assert len(mock_model_instance.ainvoke.call_args[0][0]) == 5

     # the final response should have the COT content and the final instructions
-    invoke_args = mock_model_instance.invoke.call_args[0][0]
+    invoke_args = mock_model_instance.ainvoke.call_args[0][0]
     assert isinstance(invoke_args[3], AIMessage)
     assert "Chain of thought reasoning..." in invoke_args[3].content
     assert isinstance(invoke_args[4], SystemMessage)
@@ -122,3 +129,207 @@ async def test_langchain_adapter_with_cot(tmp_path):
         == "Chain of thought reasoning..."
     )
     assert response.output == {"count": 1}
+
+
+async def test_get_structured_output_options():
+    # Mock the provider response
+    mock_provider = MagicMock()
+    mock_provider.adapter_options = {
+        "langchain": {
+            "with_structured_output_options": {
+                "force_json_response": True,
+                "max_retries": 3,
+            }
+        }
+    }
+
+    # Test with provider that has options
+    with patch(
+        "kiln_ai.adapters.langchain_adapters.kiln_model_provider_from",
+        AsyncMock(return_value=mock_provider),
+    ):
+        options = await get_structured_output_options("model_name", "provider")
+        assert options == {"force_json_response": True, "max_retries": 3}
+
+    # Test with provider that has no options
+    with patch(
+        "kiln_ai.adapters.langchain_adapters.kiln_model_provider_from",
+        AsyncMock(return_value=None),
+    ):
+        options = await get_structured_output_options("model_name", "provider")
+        assert options == {}
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_openai():
+    provider = KilnModelProvider(
+        name=ModelProviderName.openai, provider_options={"model": "gpt-4"}
+    )
+
+    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
+        mock_config.return_value.open_ai_api_key = "test_key"
+        model = await langchain_model_from_provider(provider, "gpt-4")
+        assert isinstance(model, ChatOpenAI)
+        assert model.model_name == "gpt-4"
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_groq():
+    provider = KilnModelProvider(
+        name=ModelProviderName.groq, provider_options={"model": "mixtral-8x7b"}
+    )
+
+    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
+        mock_config.return_value.groq_api_key = "test_key"
+        model = await langchain_model_from_provider(provider, "mixtral-8x7b")
+        assert isinstance(model, ChatGroq)
+        assert model.model_name == "mixtral-8x7b"
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_bedrock():
+    provider = KilnModelProvider(
+        name=ModelProviderName.amazon_bedrock,
+        provider_options={"model": "anthropic.claude-v2", "region_name": "us-east-1"},
+    )
+
+    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
+        mock_config.return_value.bedrock_access_key = "test_access"
+        mock_config.return_value.bedrock_secret_key = "test_secret"
+        model = await langchain_model_from_provider(provider, "anthropic.claude-v2")
+        assert isinstance(model, ChatBedrockConverse)
+        assert os.environ.get("AWS_ACCESS_KEY_ID") == "test_access"
+        assert os.environ.get("AWS_SECRET_ACCESS_KEY") == "test_secret"
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_fireworks():
+    provider = KilnModelProvider(
+        name=ModelProviderName.fireworks_ai, provider_options={"model": "mixtral-8x7b"}
+    )
+
+    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
+        mock_config.return_value.fireworks_api_key = "test_key"
+        model = await langchain_model_from_provider(provider, "mixtral-8x7b")
+        assert isinstance(model, ChatFireworks)
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_ollama():
+    provider = KilnModelProvider(
+        name=ModelProviderName.ollama,
+        provider_options={"model": "llama2", "model_aliases": ["llama2-uncensored"]},
+    )
+
+    mock_connection = MagicMock()
+    with (
+        patch(
+            "kiln_ai.adapters.langchain_adapters.get_ollama_connection",
+            return_value=AsyncMock(return_value=mock_connection),
+        ),
+        patch(
+            "kiln_ai.adapters.langchain_adapters.ollama_model_installed",
+            return_value=True,
+        ),
+        patch(
+            "kiln_ai.adapters.langchain_adapters.ollama_base_url",
+            return_value="http://localhost:11434",
+        ),
+    ):
+        model = await langchain_model_from_provider(provider, "llama2")
+        assert isinstance(model, ChatOllama)
+        assert model.model == "llama2"
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_invalid():
+    provider = KilnModelProvider.model_construct(
+        name="invalid_provider", provider_options={}
+    )
+
+    with pytest.raises(ValueError, match="Invalid model or provider"):
+        await langchain_model_from_provider(provider, "test_model")
+
+
+@pytest.mark.asyncio
+async def test_langchain_adapter_model_caching(tmp_path):
+    task = build_test_task(tmp_path)
+    custom_model = ChatGroq(model="mixtral-8x7b", groq_api_key="test")
+
+    adapter = LangchainAdapter(kiln_task=task, custom_model=custom_model)
+
+    # First call should return the cached model
+    model1 = await adapter.model()
+    assert model1 is custom_model
+
+    # Second call should return the same cached instance
+    model2 = await adapter.model()
+    assert model2 is model1
+
+
+@pytest.mark.asyncio
+async def test_langchain_adapter_model_structured_output(tmp_path):
+    task = build_test_task(tmp_path)
+    task.output_json_schema = """
+    {
+        "type": "object",
+        "properties": {
+            "count": {"type": "integer"}
+        }
+    }
+    """
+
+    mock_model = MagicMock()
+    mock_model.with_structured_output = MagicMock(return_value="structured_model")
+
+    adapter = LangchainAdapter(
+        kiln_task=task, model_name="test_model", provider="test_provider"
+    )
+
+    with (
+        patch(
+            "kiln_ai.adapters.langchain_adapters.langchain_model_from",
+            AsyncMock(return_value=mock_model),
+        ),
+        patch(
+            "kiln_ai.adapters.langchain_adapters.get_structured_output_options",
+            AsyncMock(return_value={"option1": "value1"}),
+        ),
+    ):
+        model = await adapter.model()
+
+        # Verify the model was configured with structured output
+        mock_model.with_structured_output.assert_called_once_with(
+            {
+                "type": "object",
+                "properties": {"count": {"type": "integer"}},
+                "title": "task_response",
+                "description": "A response from the task",
+            },
+            include_raw=True,
+            option1="value1",
+        )
+        assert model == "structured_model"
+
+
+@pytest.mark.asyncio
+async def test_langchain_adapter_model_no_structured_output_support(tmp_path):
+    task = build_test_task(tmp_path)
+    task.output_json_schema = (
+        '{"type": "object", "properties": {"count": {"type": "integer"}}}'
+    )
+
+    mock_model = MagicMock()
+    # Remove with_structured_output method
+    del mock_model.with_structured_output
+
+    adapter = LangchainAdapter(
+        kiln_task=task, model_name="test_model", provider="test_provider"
+    )
+
+    with patch(
+        "kiln_ai.adapters.langchain_adapters.langchain_model_from",
+        AsyncMock(return_value=mock_model),
+    ):
+        with pytest.raises(ValueError, match="does not support structured output"):
+            await adapter.model()
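
Aside on the new API these tests exercise: test_get_structured_output_options pins down a small contract. The helper resolves the model's provider record (via kiln_model_provider_from, which is visible here only as a patch target) and returns the provider's langchain with_structured_output_options, or an empty dict when no provider is found. Below is a minimal sketch consistent with those assertions, not the published source; the stubbed lookup is an assumption (the real helper plausibly lives in the new provider_tools module).

    from typing import Any


    async def kiln_model_provider_from(model_name: str, provider_name: str) -> Any:
        # Stub of the provider lookup the tests patch out; only its name and
        # patch path are visible in this diff.
        return None


    async def get_structured_output_options(
        model_name: str, provider_name: str
    ) -> dict[str, Any]:
        provider = await kiln_model_provider_from(model_name, provider_name)
        if provider is None:
            return {}  # no provider record -> no extra options
        # Mirrors the nesting the test builds on its mock provider:
        # adapter_options["langchain"]["with_structured_output_options"]
        return provider.adapter_options.get("langchain", {}).get(
            "with_structured_output_options", {}
        )

The model() tests above show where those options land: when the task defines an output JSON schema, they are splatted into with_structured_output(schema, include_raw=True, **options).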

kiln_ai/adapters/test_ollama_tools.py (new file)
@@ -0,0 +1,42 @@
+import json
+
+from kiln_ai.adapters.ollama_tools import (
+    OllamaConnection,
+    ollama_model_installed,
+    parse_ollama_tags,
+)
+
+
+def test_parse_ollama_tags_no_models():
+    json_response = '{"models":[{"name":"scosman_net","model":"scosman_net:latest"},{"name":"phi3.5:latest","model":"phi3.5:latest","modified_at":"2024-10-02T12:04:35.191519822-04:00","size":2176178843,"digest":"61819fb370a3c1a9be6694869331e5f85f867a079e9271d66cb223acb81d04ba","details":{"parent_model":"","format":"gguf","family":"phi3","families":["phi3"],"parameter_size":"3.8B","quantization_level":"Q4_0"}},{"name":"gemma2:2b","model":"gemma2:2b","modified_at":"2024-09-09T16:46:38.64348929-04:00","size":1629518495,"digest":"8ccf136fdd5298f3ffe2d69862750ea7fb56555fa4d5b18c04e3fa4d82ee09d7","details":{"parent_model":"","format":"gguf","family":"gemma2","families":["gemma2"],"parameter_size":"2.6B","quantization_level":"Q4_0"}},{"name":"llama3.1:latest","model":"llama3.1:latest","modified_at":"2024-09-01T17:19:43.481523695-04:00","size":4661230720,"digest":"f66fc8dc39ea206e03ff6764fcc696b1b4dfb693f0b6ef751731dd4e6269046e","details":{"parent_model":"","format":"gguf","family":"llama","families":["llama"],"parameter_size":"8.0B","quantization_level":"Q4_0"}}]}'
+    tags = json.loads(json_response)
+    print(json.dumps(tags, indent=2))
+    conn = parse_ollama_tags(tags)
+    assert "phi3.5:latest" in conn.supported_models
+    assert "gemma2:2b" in conn.supported_models
+    assert "llama3.1:latest" in conn.supported_models
+    assert "scosman_net:latest" in conn.untested_models
+
+
+def test_parse_ollama_tags_only_untested_models():
+    json_response = '{"models":[{"name":"scosman_net","model":"scosman_net:latest"}]}'
+    tags = json.loads(json_response)
+    conn = parse_ollama_tags(tags)
+    assert conn.supported_models == []
+    assert conn.untested_models == ["scosman_net:latest"]
+
+
+def test_ollama_model_installed():
+    conn = OllamaConnection(
+        supported_models=["phi3.5:latest", "gemma2:2b", "llama3.1:latest"],
+        message="Connected",
+        untested_models=["scosman_net:latest"],
+    )
+    assert ollama_model_installed(conn, "phi3.5:latest")
+    assert ollama_model_installed(conn, "phi3.5")
+    assert ollama_model_installed(conn, "gemma2:2b")
+    assert ollama_model_installed(conn, "llama3.1:latest")
+    assert ollama_model_installed(conn, "llama3.1")
+    assert ollama_model_installed(conn, "scosman_net:latest")
+    assert ollama_model_installed(conn, "scosman_net")
+    assert not ollama_model_installed(conn, "unknown_model")
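
The assertions above imply the matching rule for ollama_model_installed: a model counts as installed when its exact tag appears in either list, or when name:latest does, so a bare name like phi3.5 resolves to phi3.5:latest. A self-contained sketch of that rule follows; OllamaConnectionSketch is a stand-in for the real OllamaConnection class in ollama_tools.py, whose definition is not shown in this diff.

    from dataclasses import dataclass, field


    @dataclass
    class OllamaConnectionSketch:
        message: str
        supported_models: list[str] = field(default_factory=list)
        untested_models: list[str] = field(default_factory=list)


    def ollama_model_installed(conn: OllamaConnectionSketch, model_name: str) -> bool:
        # Both tested (supported) and untested models count as installed
        all_models = conn.supported_models + conn.untested_models
        # Exact tag match, or a bare name resolved to its :latest tag
        return model_name in all_models or f"{model_name}:latest" in all_models

Run against the fixture in test_ollama_model_installed, this sketch satisfies every assertion, including the negative unknown_model case.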

kiln_ai/adapters/test_prompt_adaptors.py
@@ -5,8 +5,10 @@ import pytest
 from langchain_core.language_models.fake_chat_models import FakeListChatModel

 import kiln_ai.datamodel as datamodel
-from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
-from kiln_ai.adapters.ml_model_list import built_in_models, ollama_online
+from kiln_ai.adapters.adapter_registry import adapter_for_task
+from kiln_ai.adapters.langchain_adapters import LangchainAdapter
+from kiln_ai.adapters.ml_model_list import built_in_models
+from kiln_ai.adapters.ollama_tools import ollama_online
 from kiln_ai.adapters.prompt_builders import (
     BasePromptBuilder,
     SimpleChainOfThoughtPromptBuilder,
@@ -106,7 +108,7 @@ async def test_amazon_bedrock(tmp_path):
 async def test_mock(tmp_path):
     task = build_test_task(tmp_path)
     mockChatModel = FakeListChatModel(responses=["mock response"])
-    adapter = LangChainPromptAdapter(task, custom_model=mockChatModel)
+    adapter = LangchainAdapter(task, custom_model=mockChatModel)
     run = await adapter.invoke("You are a mock, send me the response!")
     assert "mock response" in run.output.output

@@ -114,7 +116,7 @@ async def test_mock(tmp_path):
 async def test_mock_returning_run(tmp_path):
     task = build_test_task(tmp_path)
     mockChatModel = FakeListChatModel(responses=["mock response"])
-    adapter = LangChainPromptAdapter(task, custom_model=mockChatModel)
+    adapter = LangchainAdapter(task, custom_model=mockChatModel)
     run = await adapter.invoke("You are a mock, send me the response!")
     assert run.output.output == "mock response"
     assert run is not None
@@ -192,7 +194,7 @@ async def run_simple_task(
     provider: str,
     prompt_builder: BasePromptBuilder | None = None,
 ) -> datamodel.TaskRun:
-    adapter = LangChainPromptAdapter(
+    adapter = adapter_for_task(
         task, model_name=model_name, provider=provider, prompt_builder=prompt_builder
     )
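
The last hunk swaps the direct adapter construction in run_simple_task for adapter_for_task from the new adapter_registry module (listed above at +19 lines). Only the call signature is confirmed by this diff; what follows is a hypothetical sketch of the factory, assuming it simply routes to LangchainAdapter while that is the only adapter implementation.

    from kiln_ai import datamodel
    from kiln_ai.adapters.langchain_adapters import LangchainAdapter
    from kiln_ai.adapters.prompt_builders import BasePromptBuilder


    def adapter_for_task(
        kiln_task: datamodel.Task,
        model_name: str | None = None,
        provider: str | None = None,
        prompt_builder: BasePromptBuilder | None = None,
    ) -> LangchainAdapter:
        # Hypothetical body: the registry indirection lets later releases pick
        # other adapter types without touching call sites like run_simple_task.
        return LangchainAdapter(
            kiln_task,
            model_name=model_name,
            provider=provider,
            prompt_builder=prompt_builder,
        )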