kiln-ai 0.5.4__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- kiln_ai/adapters/base_adapter.py +24 -35
- kiln_ai/adapters/data_gen/data_gen_prompts.py +73 -0
- kiln_ai/adapters/data_gen/data_gen_task.py +117 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +292 -0
- kiln_ai/adapters/langchain_adapters.py +39 -7
- kiln_ai/adapters/ml_model_list.py +68 -1
- kiln_ai/adapters/prompt_builders.py +66 -0
- kiln_ai/adapters/repair/test_repair_task.py +4 -1
- kiln_ai/adapters/test_langchain_adapter.py +73 -0
- kiln_ai/adapters/test_ml_model_list.py +56 -0
- kiln_ai/adapters/test_prompt_adaptors.py +54 -18
- kiln_ai/adapters/test_prompt_builders.py +97 -7
- kiln_ai/adapters/test_saving_adapter_results.py +16 -6
- kiln_ai/adapters/test_structured_output.py +33 -5
- kiln_ai/datamodel/__init__.py +28 -7
- kiln_ai/datamodel/json_schema.py +1 -0
- kiln_ai/datamodel/test_models.py +44 -8
- kiln_ai/utils/config.py +3 -2
- kiln_ai/utils/test_config.py +7 -0
- {kiln_ai-0.5.4.dist-info → kiln_ai-0.6.0.dist-info}/METADATA +41 -7
- kiln_ai-0.6.0.dist-info/RECORD +36 -0
- {kiln_ai-0.5.4.dist-info → kiln_ai-0.6.0.dist-info}/WHEEL +1 -1
- kiln_ai-0.5.4.dist-info/RECORD +0 -33
- {kiln_ai-0.5.4.dist-info → kiln_ai-0.6.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/ml_model_list.py

@@ -69,6 +69,7 @@ class ModelName(str, Enum):
     gemma_2_2b = "gemma_2_2b"
     gemma_2_9b = "gemma_2_9b"
     gemma_2_27b = "gemma_2_27b"
+    claude_3_5_haiku = "claude_3_5_haiku"
     claude_3_5_sonnet = "claude_3_5_sonnet"
     gemini_1_5_flash = "gemini_1_5_flash"
     gemini_1_5_flash_8b = "gemini_1_5_flash_8b"

@@ -88,6 +89,7 @@ class KilnModelProvider(BaseModel):
 
     name: ModelProviderName
     supports_structured_output: bool = True
+    supports_data_gen: bool = True
     provider_options: Dict = {}
 
 
@@ -143,6 +145,18 @@ built_in_models: List[KilnModel] = [
             ),
         ],
     ),
+    # Claude 3.5 Haiku
+    KilnModel(
+        family=ModelFamily.claude,
+        name=ModelName.claude_3_5_haiku,
+        friendly_name="Claude 3.5 Haiku",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "anthropic/claude-3-5-haiku"},
+            ),
+        ],
+    ),
     # Claude 3.5 Sonnet
     KilnModel(
         family=ModelFamily.claude,

@@ -163,6 +177,8 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
+                supports_structured_output=False,  # it should, but doesn't work on openrouter
+                supports_data_gen=False,  # doesn't work on openrouter
                 provider_options={"model": "google/gemini-pro-1.5"},
             ),
         ],

@@ -175,6 +191,7 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
+                supports_data_gen=False,
                 provider_options={"model": "google/gemini-flash-1.5"},
             ),
         ],

@@ -187,6 +204,8 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
+                supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "google/gemini-flash-1.5-8b"},
             ),
         ],

@@ -200,6 +219,7 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "nvidia/llama-3.1-nemotron-70b-instruct"},
             ),
         ],

@@ -217,6 +237,7 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.amazon_bedrock,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={
                     "model": "meta.llama3-1-8b-instruct-v1:0",
                     "region_name": "us-west-2",  # Llama 3.1 only in west-2

@@ -224,6 +245,7 @@ built_in_models: List[KilnModel] = [
             ),
             KilnModelProvider(
                 name=ModelProviderName.ollama,
+                supports_data_gen=False,
                 provider_options={
                     "model": "llama3.1:8b",
                     "model_aliases": ["llama3.1"],  # 8b is default

@@ -232,6 +254,7 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.1-8b-instruct"},
             ),
         ],

@@ -248,7 +271,9 @@ built_in_models: List[KilnModel] = [
             ),
             KilnModelProvider(
                 name=ModelProviderName.amazon_bedrock,
+                # not sure how AWS manages to break this, but it's not working
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={
                     "model": "meta.llama3-1-70b-instruct-v1:0",
                     "region_name": "us-west-2",  # Llama 3.1 only in west-2

@@ -272,6 +297,7 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.amazon_bedrock,
+                supports_data_gen=False,
                 provider_options={
                     "model": "meta.llama3-1-405b-instruct-v1:0",
                     "region_name": "us-west-2",  # Llama 3.1 only in west-2

@@ -331,8 +357,15 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.2-3b-instruct"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={"model": "llama3.2"},
+            ),
         ],
     ),
     # Llama 3.2 11B

@@ -344,8 +377,15 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.2-11b-vision-instruct"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={"model": "llama3.2-vision"},
+            ),
         ],
     ),
     # Llama 3.2 90B

@@ -357,8 +397,15 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.2-90b-vision-instruct"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={"model": "llama3.2-vision:90b"},
+            ),
         ],
     ),
     # Phi 3.5

@@ -371,10 +418,13 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.ollama,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "phi3.5"},
             ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
+                supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "microsoft/phi-3.5-mini-128k-instruct"},
             ),
         ],

@@ -389,6 +439,7 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.ollama,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={
                     "model": "gemma2:2b",
                 },

@@ -404,12 +455,14 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.ollama,
+                supports_data_gen=False,
                 provider_options={
                     "model": "gemma2:9b",
                 },
             ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
+                supports_data_gen=False,
                 provider_options={"model": "google/gemma-2-9b-it"},
             ),
         ],

@@ -423,12 +476,14 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.ollama,
+                supports_data_gen=False,
                 provider_options={
                     "model": "gemma2:27b",
                 },
             ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
+                supports_data_gen=False,
                 provider_options={"model": "google/gemma-2-27b-it"},
             ),
         ],
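Note: the new supports_data_gen flag defaults to True on KilnModelProvider and is switched off above for the providers where generation does not work reliably. A minimal sketch of enumerating which built-in model/provider pairs still allow data generation; it only uses fields visible in this diff, and the variable names are illustrative:

from kiln_ai.adapters.ml_model_list import built_in_models

# Collect every built-in model/provider pair that still advertises data-gen support.
data_gen_capable = [
    (model.name, provider.name)
    for model in built_in_models
    for provider in model.providers
    if provider.supports_data_gen
]
print(data_gen_capable)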
@@ -436,6 +491,19 @@ built_in_models: List[KilnModel] = [
 ]
 
 
+def get_model_and_provider(
+    model_name: str, provider_name: str
+) -> tuple[KilnModel | None, KilnModelProvider | None]:
+    model = next(filter(lambda m: m.name == model_name, built_in_models), None)
+    if model is None:
+        return None, None
+    provider = next(filter(lambda p: p.name == provider_name, model.providers), None)
+    # all or nothing
+    if provider is None or model is None:
+        return None, None
+    return model, provider
+
+
 def provider_name_from_id(id: str) -> str:
     """
     Converts a provider ID to its human-readable name.
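A short usage sketch of the new get_model_and_provider helper, matching the call shape the new tests exercise further down; the phi_3_5/ollama pair is the same example those tests use:

from kiln_ai.adapters.ml_model_list import (
    ModelName,
    ModelProviderName,
    get_model_and_provider,
)

# Returns (None, None) if either the model or the provider is unknown.
model, provider = get_model_and_provider(ModelName.phi_3_5, ModelProviderName.ollama)
if model is None or provider is None:
    raise ValueError("unknown model/provider pair")
# Both capability flags can now be checked before dispatching a request.
print(provider.supports_structured_output, provider.supports_data_gen)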
@@ -674,7 +742,6 @@ def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
     models = tags["models"]
     if isinstance(models, list):
         model_names = [model["model"] for model in models]
-        print(f"model_names: {model_names}")
         available_supported_models = [
             model
             for model in model_names
kiln_ai/adapters/prompt_builders.py

@@ -54,6 +54,28 @@ class BasePromptBuilder(metaclass=ABCMeta):
 
         return f"The input is:\n{input}"
 
+    def chain_of_thought_prompt(self) -> str | None:
+        """Build and return the chain of thought prompt string.
+
+        Returns:
+            str: The constructed chain of thought prompt.
+        """
+        return None
+
+    def build_prompt_for_ui(self) -> str:
+        """Build a prompt for the UI. It includes additional instructions (like chain of thought), even if they are passed to the model in stages.
+
+        Designed for end-user consumption, not for model consumption.
+
+        Returns:
+            str: The constructed prompt string.
+        """
+        base_prompt = self.build_prompt()
+        cot_prompt = self.chain_of_thought_prompt()
+        if cot_prompt:
+            base_prompt += "\n# Thinking Instructions\n\n" + cot_prompt
+        return base_prompt
+
 
 class SimplePromptBuilder(BasePromptBuilder):
     """A basic prompt builder that combines task instruction with requirements."""
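The two new BasePromptBuilder hooks are meant to work together: chain_of_thought_prompt() returns None by default, and build_prompt_for_ui() appends whatever it returns under a "# Thinking Instructions" heading. A minimal sketch of a custom builder using the hook; the class name and the task variable are illustrative, not part of the package:

from kiln_ai.adapters.prompt_builders import SimplePromptBuilder


class MyThinkingPromptBuilder(SimplePromptBuilder):
    """Illustrative subclass that opts into chain of thought with a fixed instruction."""

    def chain_of_thought_prompt(self) -> str | None:
        return "Think step by step, explaining your reasoning."


# Assuming `task` is an existing kiln_ai Task:
# ui_prompt = MyThinkingPromptBuilder(task).build_prompt_for_ui()
# ui_prompt ends with "# Thinking Instructions" followed by the instruction above.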
@@ -187,11 +209,49 @@ class RepairsPromptBuilder(MultiShotPromptBuilder):
         return prompt_section
 
 
+def chain_of_thought_prompt(task: Task) -> str | None:
+    """Standard implementation to build and return the chain of thought prompt string.
+
+    Returns:
+        str: The constructed chain of thought prompt.
+    """
+
+    cot_instruction = task.thinking_instruction
+    if not cot_instruction:
+        cot_instruction = "Think step by step, explaining your reasoning."
+
+    return cot_instruction
+
+
+class SimpleChainOfThoughtPromptBuilder(SimplePromptBuilder):
+    """A prompt builder that includes a chain of thought prompt on top of the simple prompt."""
+
+    def chain_of_thought_prompt(self) -> str | None:
+        return chain_of_thought_prompt(self.task)
+
+
+class FewShotChainOfThoughtPromptBuilder(FewShotPromptBuilder):
+    """A prompt builder that includes a chain of thought prompt on top of the few shot prompt."""
+
+    def chain_of_thought_prompt(self) -> str | None:
+        return chain_of_thought_prompt(self.task)
+
+
+class MultiShotChainOfThoughtPromptBuilder(MultiShotPromptBuilder):
+    """A prompt builder that includes a chain of thought prompt on top of the multi shot prompt."""
+
+    def chain_of_thought_prompt(self) -> str | None:
+        return chain_of_thought_prompt(self.task)
+
+
 prompt_builder_registry = {
     "simple_prompt_builder": SimplePromptBuilder,
     "multi_shot_prompt_builder": MultiShotPromptBuilder,
     "few_shot_prompt_builder": FewShotPromptBuilder,
     "repairs_prompt_builder": RepairsPromptBuilder,
+    "simple_chain_of_thought_prompt_builder": SimpleChainOfThoughtPromptBuilder,
+    "few_shot_chain_of_thought_prompt_builder": FewShotChainOfThoughtPromptBuilder,
+    "multi_shot_chain_of_thought_prompt_builder": MultiShotChainOfThoughtPromptBuilder,
 }
 
 
@@ -217,5 +277,11 @@ def prompt_builder_from_ui_name(ui_name: str) -> type[BasePromptBuilder]:
             return MultiShotPromptBuilder
         case "repairs":
            return RepairsPromptBuilder
+        case "simple_chain_of_thought":
+            return SimpleChainOfThoughtPromptBuilder
+        case "few_shot_chain_of_thought":
+            return FewShotChainOfThoughtPromptBuilder
+        case "multi_shot_chain_of_thought":
+            return MultiShotChainOfThoughtPromptBuilder
         case _:
             raise ValueError(f"Unknown prompt builder: {ui_name}")
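The registry and the UI-name mapping expose the same three chain-of-thought builders under different keys. A small sketch of resolving one, again assuming `task` is an existing Task:

from kiln_ai.adapters.prompt_builders import (
    prompt_builder_from_ui_name,
    prompt_builder_registry,
)

# The UI name "simple_chain_of_thought" resolves to the same class registered
# under "simple_chain_of_thought_prompt_builder".
builder_class = prompt_builder_from_ui_name("simple_chain_of_thought")
assert builder_class is prompt_builder_registry["simple_chain_of_thought_prompt_builder"]
# builder = builder_class(task)
# print(builder.build_prompt_for_ui())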
kiln_ai/adapters/repair/test_repair_task.py

@@ -5,6 +5,7 @@ from unittest.mock import AsyncMock, patch
 import pytest
 from pydantic import ValidationError
 
+from kiln_ai.adapters.base_adapter import RunOutput
 from kiln_ai.adapters.langchain_adapters import (
     LangChainPromptAdapter,
 )

@@ -222,7 +223,9 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
     with patch.object(
         LangChainPromptAdapter, "_run", new_callable=AsyncMock
     ) as mock_run:
-        mock_run.return_value =
+        mock_run.return_value = RunOutput(
+            output=mocked_output, intermediate_outputs=None
+        )
 
         adapter = LangChainPromptAdapter(
             repair_task, model_name="llama_3_1_8b", provider="groq"
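The repair test now wraps its mocked result in the adapter-layer RunOutput type instead of returning a bare value. A minimal sketch of that shape, using placeholder values; the field names follow the construction shown above and the chain-of-thought test below:

from kiln_ai.adapters.base_adapter import RunOutput

# _run-style results now carry the parsed output plus optional intermediate steps.
run_output = RunOutput(
    output={"count": 1},
    intermediate_outputs={"chain_of_thought": "reasoning text..."},
)
print(run_output.output, run_output.intermediate_outputs)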
kiln_ai/adapters/test_langchain_adapter.py

@@ -1,6 +1,10 @@
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_groq import ChatGroq
 
 from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
+from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder
 from kiln_ai.adapters.test_prompt_adaptors import build_test_task
 
 
@@ -49,3 +53,72 @@ def test_langchain_adapter_info(tmp_path):
     assert model_info.adapter_name == "kiln_langchain_adapter"
     assert model_info.model_name == "llama_3_1_8b"
     assert model_info.model_provider == "ollama"
+
+
+async def test_langchain_adapter_with_cot(tmp_path):
+    task = build_test_task(tmp_path)
+    task.output_json_schema = (
+        '{"type": "object", "properties": {"count": {"type": "integer"}}}'
+    )
+    lca = LangChainPromptAdapter(
+        kiln_task=task,
+        model_name="llama_3_1_8b",
+        provider="ollama",
+        prompt_builder=SimpleChainOfThoughtPromptBuilder(task),
+    )
+
+    # Mock the base model and its invoke method
+    mock_base_model = MagicMock()
+    mock_base_model.invoke.return_value = AIMessage(
+        content="Chain of thought reasoning..."
+    )
+
+    # Create a separate mock for self.model()
+    mock_model_instance = MagicMock()
+    mock_model_instance.invoke.return_value = {"parsed": {"count": 1}}
+
+    # Mock the langchain_model_from function to return the base model
+    mock_model_from = AsyncMock(return_value=mock_base_model)
+
+    # Patch both the langchain_model_from function and self.model()
+    with (
+        patch(
+            "kiln_ai.adapters.langchain_adapters.langchain_model_from", mock_model_from
+        ),
+        patch.object(LangChainPromptAdapter, "model", return_value=mock_model_instance),
+    ):
+        response = await lca._run("test input")
+
+    # First 3 messages are the same for both calls
+    for invoke_args in [
+        mock_base_model.invoke.call_args[0][0],
+        mock_model_instance.invoke.call_args[0][0],
+    ]:
+        assert isinstance(
+            invoke_args[0], SystemMessage
+        )  # First message should be system prompt
+        assert (
+            "You are an assistant which performs math tasks provided in plain text."
+            in invoke_args[0].content
+        )
+        assert isinstance(invoke_args[1], HumanMessage)
+        assert "test input" in invoke_args[1].content
+        assert isinstance(invoke_args[2], SystemMessage)
+        assert "step by step" in invoke_args[2].content
+
+    # the COT should only have 3 messages
+    assert len(mock_base_model.invoke.call_args[0][0]) == 3
+    assert len(mock_model_instance.invoke.call_args[0][0]) == 5
+
+    # the final response should have the COT content and the final instructions
+    invoke_args = mock_model_instance.invoke.call_args[0][0]
+    assert isinstance(invoke_args[3], AIMessage)
+    assert "Chain of thought reasoning..." in invoke_args[3].content
+    assert isinstance(invoke_args[4], SystemMessage)
+    assert "Considering the above, return a final result." in invoke_args[4].content
+
+    assert (
+        response.intermediate_outputs["chain_of_thought"]
+        == "Chain of thought reasoning..."
+    )
+    assert response.output == {"count": 1}
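The assertions above pin down the adapter's two-stage chain-of-thought flow: a first call with the system prompt, the user input, and the thinking instructions, then a second call that appends the model's reasoning plus a closing instruction asking for the final result. A rough illustration of the two message lists, built only from the strings the test asserts (the adapter's internal construction may differ):

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

system_prompt = "You are an assistant which performs math tasks provided in plain text."
cot_messages = [
    SystemMessage(content=system_prompt),
    HumanMessage(content="test input"),
    SystemMessage(content="Think step by step, explaining your reasoning."),
]
# The reasoning from the first call is fed back in, followed by a final instruction.
final_messages = cot_messages + [
    AIMessage(content="Chain of thought reasoning..."),
    SystemMessage(content="Considering the above, return a final result."),
]
assert len(cot_messages) == 3 and len(final_messages) == 5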
kiln_ai/adapters/test_ml_model_list.py

@@ -4,9 +4,11 @@ from unittest.mock import patch
 import pytest
 
 from kiln_ai.adapters.ml_model_list import (
+    ModelName,
     ModelProviderName,
     OllamaConnection,
     check_provider_warnings,
+    get_model_and_provider,
     ollama_model_supported,
     parse_ollama_tags,
     provider_name_from_id,
@@ -123,3 +125,57 @@ def test_ollama_model_supported():
     assert ollama_model_supported(conn, "llama3.1:latest")
     assert ollama_model_supported(conn, "llama3.1")
     assert not ollama_model_supported(conn, "unknown_model")
+
+
+def test_get_model_and_provider_valid():
+    # Test with a known valid model and provider combination
+    model, provider = get_model_and_provider(
+        ModelName.phi_3_5, ModelProviderName.ollama
+    )
+
+    assert model is not None
+    assert provider is not None
+    assert model.name == ModelName.phi_3_5
+    assert provider.name == ModelProviderName.ollama
+    assert provider.provider_options["model"] == "phi3.5"
+
+
+def test_get_model_and_provider_invalid_model():
+    # Test with an invalid model name
+    model, provider = get_model_and_provider(
+        "nonexistent_model", ModelProviderName.ollama
+    )
+
+    assert model is None
+    assert provider is None
+
+
+def test_get_model_and_provider_invalid_provider():
+    # Test with a valid model but invalid provider
+    model, provider = get_model_and_provider(ModelName.phi_3_5, "nonexistent_provider")
+
+    assert model is None
+    assert provider is None
+
+
+def test_get_model_and_provider_valid_model_wrong_provider():
+    # Test with a valid model but a provider that doesn't support it
+    model, provider = get_model_and_provider(
+        ModelName.phi_3_5, ModelProviderName.amazon_bedrock
+    )
+
+    assert model is None
+    assert provider is None
+
+
+def test_get_model_and_provider_multiple_providers():
+    # Test with a model that has multiple providers
+    model, provider = get_model_and_provider(
+        ModelName.llama_3_1_70b, ModelProviderName.groq
+    )
+
+    assert model is not None
+    assert provider is not None
+    assert model.name == ModelName.llama_3_1_70b
+    assert provider.name == ModelProviderName.groq
+    assert provider.provider_options["model"] == "llama-3.1-70b-versatile"
kiln_ai/adapters/test_prompt_adaptors.py

@@ -7,6 +7,18 @@ from langchain_core.language_models.fake_chat_models import FakeListChatModel
 import kiln_ai.datamodel as datamodel
 from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
 from kiln_ai.adapters.ml_model_list import built_in_models, ollama_online
+from kiln_ai.adapters.prompt_builders import (
+    BasePromptBuilder,
+    SimpleChainOfThoughtPromptBuilder,
+)
+
+
+def get_all_models_and_providers():
+    model_provider_pairs = []
+    for model in built_in_models:
+        for provider in model.providers:
+            model_provider_pairs.append((model.name, provider.name))
+    return model_provider_pairs
 
 
 @pytest.mark.paid
@@ -28,6 +40,9 @@ async def test_groq(tmp_path):
         "llama_3_2_3b",
         "llama_3_2_11b",
         "llama_3_2_90b",
+        "claude_3_5_haiku",
+        "claude_3_5_sonnet",
+        "phi_3_5",
     ],
 )
 @pytest.mark.paid
@@ -117,15 +132,19 @@ async def test_mock_returning_run(tmp_path):
 
 @pytest.mark.paid
 @pytest.mark.ollama
-
+@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+async def test_all_models_providers_plaintext(tmp_path, model_name, provider_name):
     task = build_test_task(tmp_path)
-
-
-
-
-
-
+    await run_simple_task(task, model_name, provider_name)
+
+
+@pytest.mark.paid
+@pytest.mark.ollama
+@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+async def test_cot_prompt_builder(tmp_path, model_name, provider_name):
+    task = build_test_task(tmp_path)
+    pb = SimpleChainOfThoughtPromptBuilder(task)
+    await run_simple_task(task, model_name, provider_name, pb)
 
 
 def build_test_task(tmp_path: Path):
@@ -157,13 +176,25 @@ def build_test_task(tmp_path: Path):
     return task
 
 
-async def run_simple_test(
+async def run_simple_test(
+    tmp_path: Path,
+    model_name: str,
+    provider: str | None = None,
+    prompt_builder: BasePromptBuilder | None = None,
+):
     task = build_test_task(tmp_path)
-    return await run_simple_task(task, model_name, provider)
+    return await run_simple_task(task, model_name, provider, prompt_builder)
 
 
-async def run_simple_task(
-
+async def run_simple_task(
+    task: datamodel.Task,
+    model_name: str,
+    provider: str,
+    prompt_builder: BasePromptBuilder | None = None,
+) -> datamodel.TaskRun:
+    adapter = LangChainPromptAdapter(
+        task, model_name=model_name, provider=provider, prompt_builder=prompt_builder
+    )
 
     run = await adapter.invoke(
         "You should answer the following question: four plus six times 10"
@@ -174,9 +205,14 @@ async def run_simple_task(task: datamodel.Task, model_name: str, provider: str):
         run.input == "You should answer the following question: four plus six times 10"
     )
     assert "64" in run.output.output
-
-
-
-
-
-
+    source_props = run.output.source.properties
+    assert source_props["adapter_name"] == "kiln_langchain_adapter"
+    assert source_props["model_name"] == model_name
+    assert source_props["model_provider"] == provider
+    expected_prompt_builder_name = (
+        prompt_builder.__class__.prompt_builder_name()
+        if prompt_builder
+        else "simple_prompt_builder"
+    )
+    assert source_props["prompt_builder_name"] == expected_prompt_builder_name
+    return run