kiln-ai 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registry.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/__init__.py +3 -0
- kiln_ai/adapters/__init__.py +23 -0
- kiln_ai/adapters/base_adapter.py +29 -0
- kiln_ai/adapters/langchain_adapters.py +3 -2
- kiln_ai/adapters/ml_model_list.py +89 -4
- kiln_ai/adapters/prompt_builders.py +114 -15
- kiln_ai/adapters/repair/__init__.py +11 -0
- kiln_ai/adapters/repair/repair_task.py +2 -1
- kiln_ai/adapters/repair/test_repair_task.py +2 -1
- kiln_ai/adapters/test_langchain_adapter.py +2 -1
- kiln_ai/adapters/test_ml_model_list.py +2 -2
- kiln_ai/adapters/test_prompt_adaptors.py +3 -2
- kiln_ai/adapters/test_prompt_builders.py +25 -2
- kiln_ai/adapters/test_saving_adapter_results.py +1 -0
- kiln_ai/adapters/test_structured_output.py +2 -1
- kiln_ai/datamodel/__init__.py +81 -29
- kiln_ai/datamodel/basemodel.py +84 -3
- kiln_ai/datamodel/json_schema.py +35 -1
- kiln_ai/datamodel/test_basemodel.py +31 -0
- kiln_ai/datamodel/test_datasource.py +5 -6
- kiln_ai/datamodel/test_example_models.py +11 -40
- kiln_ai/datamodel/test_json_schema.py +2 -1
- kiln_ai/datamodel/test_models.py +2 -1
- kiln_ai/datamodel/test_nested_save.py +2 -1
- kiln_ai/datamodel/test_output_rating.py +2 -1
- kiln_ai/utils/__init__.py +12 -0
- kiln_ai/utils/test_config.py +1 -0
- kiln_ai-0.5.2.dist-info/METADATA +48 -0
- kiln_ai-0.5.2.dist-info/RECORD +33 -0
- {kiln_ai-0.5.0.dist-info → kiln_ai-0.5.2.dist-info}/WHEEL +1 -1
- kiln_ai-0.5.0.dist-info/METADATA +0 -37
- kiln_ai-0.5.0.dist-info/RECORD +0 -29
- {kiln_ai-0.5.0.dist-info → kiln_ai-0.5.2.dist-info/licenses}/LICENSE.txt +0 -0
kiln_ai/adapters/__init__.py
ADDED
@@ -0,0 +1,23 @@
+"""
+# Adapters
+
+Adapters are used to connect Kiln to external systems, or to add new functionality to Kiln.
+
+BaseAdapter is extensible, and used for adding adapters that provide AI functionality. There's currently a LangChain adapter which provides a bridge to LangChain.
+
+The ml_model_list submodule contains a list of models that can be used for machine learning tasks. More can easily be added, but we keep a list here of models that are known to work well with Kiln's structured data and tool calling systems.
+
+The prompt_builders submodule contains classes that build prompts for use with the AI agents.
+
+The repair submodule contains an adapter for the repair task.
+"""
+
+from . import base_adapter, langchain_adapters, ml_model_list, prompt_builders, repair
+
+__all__ = [
+    "base_adapter",
+    "langchain_adapters",
+    "ml_model_list",
+    "prompt_builders",
+    "repair",
+]
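The package docstring above lists the submodules that kiln_ai.adapters exposes. A minimal sketch of how they fit together follows; it only uses names that appear elsewhere in this diff (built_in_models, prompt_builder_registry), and the printing is purely illustrative.

```python
# Sketch: enumerate the bundled model list and the registered prompt builders.
from kiln_ai.adapters import ml_model_list, prompt_builders

for model in ml_model_list.built_in_models:
    # Each KilnModel lists the providers that can serve it.
    print(model.friendly_name, [provider.name for provider in model.providers])

# Registry keys ("simple_prompt_builder", "repairs_prompt_builder", ...) are the
# names persisted with saved runs.
print(list(prompt_builders.prompt_builder_registry.keys()))
```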
kiln_ai/adapters/base_adapter.py
CHANGED
@@ -25,6 +25,35 @@ class AdapterInfo:
 
 
 class BaseAdapter(metaclass=ABCMeta):
+    """Base class for AI model adapters that handle task execution.
+
+    This abstract class provides the foundation for implementing model-specific adapters
+    that can process tasks with structured or unstructured inputs/outputs. It handles
+    input/output validation, prompt building, and run tracking.
+
+    Attributes:
+        prompt_builder (BasePromptBuilder): Builder for constructing prompts for the model
+        kiln_task (Task): The task configuration and metadata
+        output_schema (dict | None): JSON schema for validating structured outputs
+        input_schema (dict | None): JSON schema for validating structured inputs
+
+    Example:
+    ```python
+    class CustomAdapter(BaseAdapter):
+        async def _run(self, input: Dict | str) -> Dict | str:
+            # Implementation for specific model
+            pass
+
+        def adapter_info(self) -> AdapterInfo:
+            return AdapterInfo(
+                adapter_name="custom",
+                model_name="model-1",
+                model_provider="provider",
+                prompt_builder_name="simple"
+            )
+    ```
+    """
+
     def __init__(
         self, kiln_task: Task, prompt_builder: BasePromptBuilder | None = None
     ):
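The new docstring sketches the two methods a subclass supplies. A slightly fuller, self-contained version of that example is below; EchoAdapter and its echo behaviour are illustrative only, not part of the package.

```python
# Sketch of a concrete adapter against the BaseAdapter interface described above.
from typing import Dict

from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter


class EchoAdapter(BaseAdapter):
    """Returns the rendered user message unchanged; handy for wiring tests."""

    async def _run(self, input: Dict | str) -> Dict | str:
        # A real adapter would call a model here; we just echo the built message.
        return self.prompt_builder.build_user_message(input)

    def adapter_info(self) -> AdapterInfo:
        return AdapterInfo(
            adapter_name="echo",
            model_name="none",
            model_provider="none",
            prompt_builder_name=self.prompt_builder.__class__.prompt_builder_name(),
        )
```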
kiln_ai/adapters/langchain_adapters.py
CHANGED
@@ -1,11 +1,11 @@
 from typing import Dict
 
-import kiln_ai.datamodel as datamodel
-from kiln_ai.adapters.prompt_builders import SimplePromptBuilder
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import HumanMessage, SystemMessage
 from langchain_core.messages.base import BaseMessage
 
+import kiln_ai.datamodel as datamodel
+
 from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder
 from .ml_model_list import langchain_model_from
 
@@ -77,6 +77,7 @@ class LangChainPromptAdapter(BaseAdapter):
             HumanMessage(content=user_msg),
         ]
         response = self.model.invoke(messages)
+
         if self.has_structured_output():
             if (
                 not isinstance(response, dict)
kiln_ai/adapters/ml_model_list.py
CHANGED
@@ -14,8 +14,18 @@ from pydantic import BaseModel
 
 from ..utils.config import Config
 
+"""
+Provides model configuration and management for various LLM providers and models.
+This module handles the integration with different AI model providers and their respective models,
+including configuration, validation, and instantiation of language models.
+"""
+
 
 class ModelProviderName(str, Enum):
+    """
+    Enumeration of supported AI model providers.
+    """
+
     openai = "openai"
     groq = "groq"
     amazon_bedrock = "amazon_bedrock"
@@ -24,6 +34,10 @@ class ModelProviderName(str, Enum):
 
 
 class ModelFamily(str, Enum):
+    """
+    Enumeration of supported model families/architectures.
+    """
+
     gpt = "gpt"
     llama = "llama"
     phi = "phi"
@@ -33,6 +47,11 @@
 
 # Where models have instruct and raw versions, instruct is default and raw is specified
 class ModelName(str, Enum):
+    """
+    Enumeration of specific model versions supported by the system.
+    Where models have instruct and raw versions, instruct is default and raw is specified.
+    """
+
     llama_3_1_8b = "llama_3_1_8b"
     llama_3_1_70b = "llama_3_1_70b"
     llama_3_1_405b = "llama_3_1_405b"
@@ -47,13 +66,32 @@
 
 
 class KilnModelProvider(BaseModel):
+    """
+    Configuration for a specific model provider.
+
+    Attributes:
+        name: The provider's identifier
+        supports_structured_output: Whether the provider supports structured output formats
+        provider_options: Additional provider-specific configuration options
+    """
+
     name: ModelProviderName
-    # Allow overriding the model level setting
     supports_structured_output: bool = True
     provider_options: Dict = {}
 
 
 class KilnModel(BaseModel):
+    """
+    Configuration for a specific AI model.
+
+    Attributes:
+        family: The model's architecture family
+        name: The model's identifier
+        friendly_name: Human-readable name for the model
+        providers: List of providers that offer this model
+        supports_structured_output: Whether the model supports structured output formats
+    """
+
     family: str
     name: str
     friendly_name: str
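Given the KilnModel and KilnModelProvider fields documented above, an entry in built_in_models plausibly looks like the sketch below. The enum members are ones named in this diff; the friendly name and provider_options values are made up for illustration.

```python
# Illustrative sketch of a built_in_models entry, not copied from the library.
from kiln_ai.adapters.ml_model_list import (
    KilnModel,
    KilnModelProvider,
    ModelFamily,
    ModelName,
    ModelProviderName,
)

example_model = KilnModel(
    family=ModelFamily.llama,
    name=ModelName.llama_3_1_8b,
    friendly_name="Llama 3.1 8B",  # assumed display name
    providers=[
        KilnModelProvider(
            name=ModelProviderName.groq,
            provider_options={"model": "llama-3.1-8b-instant"},  # assumed option key/value
        ),
    ],
)
```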
@@ -292,6 +330,18 @@ built_in_models: List[KilnModel] = [
 
 
 def provider_name_from_id(id: str) -> str:
+    """
+    Converts a provider ID to its human-readable name.
+
+    Args:
+        id: The provider identifier string
+
+    Returns:
+        The human-readable name of the provider
+
+    Raises:
+        ValueError: If the provider ID is invalid or unhandled
+    """
     if id in ModelProviderName.__members__:
         enum_id = ModelProviderName(id)
         match enum_id:
@@ -350,6 +400,15 @@ def get_config_value(key: str):
 
 
 def check_provider_warnings(provider_name: ModelProviderName):
+    """
+    Validates that required configuration is present for a given provider.
+
+    Args:
+        provider_name: The provider to check
+
+    Raises:
+        ValueError: If required configuration keys are missing
+    """
     warning_check = provider_warnings.get(provider_name)
     if warning_check is None:
         return
@@ -359,6 +418,19 @@ def check_provider_warnings(provider_name: ModelProviderName):
 
 
 def langchain_model_from(name: str, provider_name: str | None = None) -> BaseChatModel:
+    """
+    Creates a LangChain chat model instance for the specified model and provider.
+
+    Args:
+        name: The name of the model to instantiate
+        provider_name: Optional specific provider to use (defaults to first available)
+
+    Returns:
+        A configured LangChain chat model instance
+
+    Raises:
+        ValueError: If the model/provider combination is invalid or misconfigured
+    """
     if name not in ModelName.__members__:
         raise ValueError(f"Invalid name: {name}")
 
@@ -413,7 +485,7 @@ def langchain_model_from(name: str, provider_name: str | None = None) -> BaseChatModel:
             openai_api_key=api_key,  # type: ignore[arg-type]
             openai_api_base=base_url,  # type: ignore[arg-type]
             default_headers={
-                "HTTP-Referer": "https://
+                "HTTP-Referer": "https://getkiln.ai/openrouter",
                 "X-Title": "KilnAI",
             },
         )
@@ -421,14 +493,27 @@ def langchain_model_from(name: str, provider_name: str | None = None) -> BaseChatModel:
         raise ValueError(f"Invalid model or provider: {name} - {provider_name}")
 
 
-def ollama_base_url():
+def ollama_base_url() -> str:
+    """
+    Gets the base URL for Ollama API connections.
+
+    Returns:
+        The base URL to use for Ollama API calls, using environment variable if set
+        or falling back to localhost default
+    """
     env_base_url = os.getenv("OLLAMA_BASE_URL")
     if env_base_url is not None:
         return env_base_url
     return "http://localhost:11434"
 
 
-async def ollama_online():
+async def ollama_online() -> bool:
+    """
+    Checks if the Ollama service is available and responding.
+
+    Returns:
+        True if Ollama is available and responding, False otherwise
+    """
     try:
         httpx.get(ollama_base_url() + "/api/tags")
     except httpx.RequestError:
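A hedged usage sketch of the helpers documented above. The model and provider names come from the enums in this diff; running it for real needs the relevant API keys (or a local Ollama server), so treat it as illustrative.

```python
import asyncio

from kiln_ai.adapters.ml_model_list import (
    ModelProviderName,
    check_provider_warnings,
    langchain_model_from,
    ollama_online,
)

# Raises ValueError if required configuration (e.g. an API key) is missing.
check_provider_warnings(ModelProviderName.groq)

# Returns a configured LangChain chat model for the named model/provider pair
# (assumes the built-in list offers this model via Groq).
chat_model = langchain_model_from("llama_3_1_8b", provider_name="groq")
print(type(chat_model).__name__)

# True only if a local Ollama server answers /api/tags.
print(asyncio.run(ollama_online()))
```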
kiln_ai/adapters/prompt_builders.py
CHANGED
@@ -2,25 +2,53 @@ import json
 from abc import ABCMeta, abstractmethod
 from typing import Dict
 
-from kiln_ai.datamodel import Task
+from kiln_ai.datamodel import Task, TaskRun
 from kiln_ai.utils.formatting import snake_case
 
 
 class BasePromptBuilder(metaclass=ABCMeta):
+    """Base class for building prompts from tasks.
+
+    Provides the core interface and basic functionality for prompt builders.
+    """
+
     def __init__(self, task: Task):
+        """Initialize the prompt builder with a task.
+
+        Args:
+            task (Task): The task containing instructions and requirements.
+        """
         self.task = task
 
     @abstractmethod
     def build_prompt(self) -> str:
+        """Build and return the complete prompt string.
+
+        Returns:
+            str: The constructed prompt.
+        """
         pass
 
-    # override to change the name of the prompt builder (if changing class names)
     @classmethod
     def prompt_builder_name(cls) -> str:
+        """Returns the name of the prompt builder, to be used for persisting into the datastore.
+
+        Default implementation gets the name of the prompt builder in snake case. If you change the class name, you should override this so prior saved data is compatible.
+
+        Returns:
+            str: The prompt builder name in snake_case format.
+        """
         return snake_case(cls.__name__)
 
-    # Can be overridden to add more information to the user message
     def build_user_message(self, input: Dict | str) -> str:
+        """Build a user message from the input.
+
+        Args:
+            input (Union[Dict, str]): The input to format into a message.
+
+        Returns:
+            str: The formatted user message.
+        """
         if isinstance(input, Dict):
             return f"The input is:\n{json.dumps(input, indent=2)}"
 
@@ -28,7 +56,14 @@ class BasePromptBuilder(metaclass=ABCMeta):
 
 
 class SimplePromptBuilder(BasePromptBuilder):
+    """A basic prompt builder that combines task instruction with requirements."""
+
     def build_prompt(self) -> str:
+        """Build a simple prompt with instruction and requirements.
+
+        Returns:
+            str: The constructed prompt string.
+        """
         base_prompt = self.task.instruction
 
         # TODO: this is just a quick version. Formatting and best practices TBD
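The BasePromptBuilder interface above only requires build_prompt to be overridden. A sketch of a custom builder follows; the class and its behaviour are illustrative, not part of the package.

```python
from kiln_ai.adapters.prompt_builders import BasePromptBuilder


class ShoutingPromptBuilder(BasePromptBuilder):
    """Upper-cases the task instruction; build_prompt is the only required override."""

    def build_prompt(self) -> str:
        return self.task.instruction.upper()


# The default prompt_builder_name() is the snake_case class name, so runs made with
# this builder would be persisted as "shouting_prompt_builder".
print(ShoutingPromptBuilder.prompt_builder_name())
```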
@@ -44,11 +79,23 @@
 
 
 class MultiShotPromptBuilder(BasePromptBuilder):
+    """A prompt builder that includes multiple examples in the prompt."""
+
     @classmethod
     def example_count(cls) -> int:
+        """Get the maximum number of examples to include in the prompt.
+
+        Returns:
+            int: The maximum number of examples (default 25).
+        """
         return 25
 
     def build_prompt(self) -> str:
+        """Build a prompt with instruction, requirements, and multiple examples.
+
+        Returns:
+            str: The constructed prompt string with examples.
+        """
         base_prompt = f"# Instruction\n\n{ self.task.instruction }\n\n"
 
         if len(self.task.requirements) > 0:
@@ -57,7 +104,24 @@
             base_prompt += f"{i+1}) {requirement.instruction}\n"
         base_prompt += "\n"
 
-        valid_examples
+        valid_examples = self.collect_examples()
+
+        if len(valid_examples) == 0:
+            return base_prompt
+
+        base_prompt += "# Example Outputs\n\n"
+        for i, example in enumerate(valid_examples):
+            base_prompt += self.prompt_section_for_example(i, example)
+
+        return base_prompt
+
+    def prompt_section_for_example(self, index: int, example: TaskRun) -> str:
+        # Prefer repaired output if it exists, otherwise use the regular output
+        output = example.repaired_output or example.output
+        return f"## Example {index+1}\n\nInput: {example.input}\nOutput: {output.output}\n\n"
+
+    def collect_examples(self) -> list[TaskRun]:
+        valid_examples: list[TaskRun] = []
         runs = self.task.runs()
 
         # first pass, we look for repaired outputs. These are the best examples.
@@ -65,7 +129,7 @@
             if len(valid_examples) >= self.__class__.example_count():
                 break
             if run.repaired_output is not None:
-                valid_examples.append(
+                valid_examples.append(run)
 
         # second pass, we look for high quality outputs (rating based)
         # Minimum is "high_quality" (4 star in star rating scale), then sort by rating
@@ -84,33 +148,66 @@
         for run in runs_with_rating:
             if len(valid_examples) >= self.__class__.example_count():
                 break
-            valid_examples.append(
-
-        if len(valid_examples) > 0:
-            base_prompt += "# Example Outputs\n\n"
-            for i, example in enumerate(valid_examples):
-                base_prompt += (
-                    f"## Example {i+1}\n\nInput: {example[0]}\nOutput: {example[1]}\n\n"
-                )
-
-        return base_prompt
+            valid_examples.append(run)
+        return valid_examples
 
 
 class FewShotPromptBuilder(MultiShotPromptBuilder):
+    """A prompt builder that includes a small number of examples in the prompt."""
+
     @classmethod
     def example_count(cls) -> int:
+        """Get the maximum number of examples to include in the prompt.
+
+        Returns:
+            int: The maximum number of examples (4).
+        """
         return 4
 
 
+class RepairsPromptBuilder(MultiShotPromptBuilder):
+    """A prompt builder that includes multiple examples in the prompt, including repaired instructions describing what was wrong, and how it was fixed."""
+
+    def prompt_section_for_example(self, index: int, example: TaskRun) -> str:
+        if (
+            not example.repaired_output
+            or not example.repair_instructions
+            or not example.repaired_output.output
+        ):
+            return super().prompt_section_for_example(index, example)
+
+        prompt_section = f"## Example {index+1}\n\nInput: {example.input}\n\n"
+        prompt_section += (
+            f"Initial Output Which Was Insufficient: {example.output.output}\n\n"
+        )
+        prompt_section += f"Instructions On How to Improve the Initial Output: {example.repair_instructions}\n\n"
+        prompt_section += (
+            f"Repaired Output Which is Sufficient: {example.repaired_output.output}\n\n"
+        )
+        return prompt_section
+
+
 prompt_builder_registry = {
     "simple_prompt_builder": SimplePromptBuilder,
     "multi_shot_prompt_builder": MultiShotPromptBuilder,
     "few_shot_prompt_builder": FewShotPromptBuilder,
+    "repairs_prompt_builder": RepairsPromptBuilder,
 }
 
 
 # Our UI has some names that are not the same as the class names, which also hint parameters.
 def prompt_builder_from_ui_name(ui_name: str) -> type[BasePromptBuilder]:
+    """Convert a name used in the UI to the corresponding prompt builder class.
+
+    Args:
+        ui_name (str): The UI name for the prompt builder type.
+
+    Returns:
+        type[BasePromptBuilder]: The corresponding prompt builder class.
+
+    Raises:
+        ValueError: If the UI name is not recognized.
+    """
     match ui_name:
         case "basic":
             return SimplePromptBuilder
@@ -118,5 +215,7 @@ def prompt_builder_from_ui_name(ui_name: str) -> type[BasePromptBuilder]:
             return FewShotPromptBuilder
         case "many_shot":
            return MultiShotPromptBuilder
+        case "repairs":
+            return RepairsPromptBuilder
         case _:
             raise ValueError(f"Unknown prompt builder: {ui_name}")
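The new RepairsPromptBuilder is also reachable through the UI-name lookup added above. A hedged sketch of resolving and using it; `task` is assumed to be an existing kiln_ai.datamodel.Task that already has saved runs with repairs.

```python
from kiln_ai.adapters.prompt_builders import prompt_builder_from_ui_name

builder_class = prompt_builder_from_ui_name("repairs")  # -> RepairsPromptBuilder
builder = builder_class(task=task)  # `task` assumed to exist with saved, repaired runs

# Repaired examples are rendered with the repair instructions and the fixed output,
# as exercised by test_repair_multi_shot_prompt_builder later in this diff.
print(builder.build_prompt())
```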
kiln_ai/adapters/repair/repair_task.py
CHANGED
@@ -1,9 +1,10 @@
 import json
 from typing import Type
 
+from pydantic import BaseModel, Field
+
 from kiln_ai.adapters.prompt_builders import BasePromptBuilder, prompt_builder_registry
 from kiln_ai.datamodel import Priority, Project, Task, TaskRequirement, TaskRun
-from pydantic import BaseModel, Field
 
 
 # TODO add evaluator rating
kiln_ai/adapters/repair/test_repair_task.py
CHANGED
@@ -3,6 +3,8 @@ import os
 from unittest.mock import AsyncMock, patch
 
 import pytest
+from pydantic import ValidationError
+
 from kiln_ai.adapters.langchain_adapters import (
     LangChainPromptAdapter,
 )
@@ -19,7 +21,6 @@ from kiln_ai.datamodel import (
     TaskRequirement,
     TaskRun,
 )
-from pydantic import ValidationError
 
 json_joke_schema = """{
     "type": "object",
kiln_ai/adapters/test_langchain_adapter.py
CHANGED
@@ -1,6 +1,7 @@
+from langchain_groq import ChatGroq
+
 from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
 from kiln_ai.adapters.test_prompt_adaptors import build_test_task
-from langchain_groq import ChatGroq
 
 
 def test_langchain_adapter_munge_response(tmp_path):
kiln_ai/adapters/test_ml_model_list.py
CHANGED
@@ -2,7 +2,7 @@ from unittest.mock import patch
 
 import pytest
 
-from 
+from kiln_ai.adapters.ml_model_list import (
     ModelProviderName,
     check_provider_warnings,
     provider_name_from_id,
@@ -12,7 +12,7 @@ from libs.core.kiln_ai.adapters.ml_model_list import (
 
 @pytest.fixture
 def mock_config():
-    with patch("
+    with patch("kiln_ai.adapters.ml_model_list.get_config_value") as mock:
         yield mock
 
 
kiln_ai/adapters/test_prompt_adaptors.py
CHANGED
@@ -1,11 +1,12 @@
 import os
 from pathlib import Path
 
-import kiln_ai.datamodel as datamodel
 import pytest
+from langchain_core.language_models.fake_chat_models import FakeListChatModel
+
+import kiln_ai.datamodel as datamodel
 from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
 from kiln_ai.adapters.ml_model_list import built_in_models, ollama_online
-from langchain_core.language_models.fake_chat_models import FakeListChatModel
 
 
 @pytest.mark.paid
kiln_ai/adapters/test_prompt_builders.py
CHANGED
@@ -1,10 +1,12 @@
 import json
 
 import pytest
+
 from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter
 from kiln_ai.adapters.prompt_builders import (
     FewShotPromptBuilder,
     MultiShotPromptBuilder,
+    RepairsPromptBuilder,
     SimplePromptBuilder,
     prompt_builder_from_ui_name,
 )
@@ -71,7 +73,8 @@ def test_simple_prompt_builder_structured_output(tmp_path):
     assert input not in prompt
 
 
-
+@pytest.fixture
+def task_with_examples(tmp_path):
     # Create a project and task hierarchy
     project = Project(name="Test Project", path=(tmp_path / "test_project.kiln"))
     project.save_to_file()
@@ -192,9 +195,12 @@ def test_multi_shot_prompt_builder(tmp_path):
     )
     e3.save_to_file()
     check_example_outputs(task, 3)
+    return task
+
 
+def test_multi_shot_prompt_builder(task_with_examples):
     # Verify the order of examples
-    prompt_builder = MultiShotPromptBuilder(task=
+    prompt_builder = MultiShotPromptBuilder(task=task_with_examples)
     prompt = prompt_builder.build_prompt()
     assert "Why did the cow cross the road?" in prompt
     assert prompt.index("Why did the cow cross the road?") < prompt.index(
@@ -299,12 +305,14 @@ def check_example_outputs(task: Task, count: int):
 def test_prompt_builder_name():
     assert SimplePromptBuilder.prompt_builder_name() == "simple_prompt_builder"
     assert MultiShotPromptBuilder.prompt_builder_name() == "multi_shot_prompt_builder"
+    assert RepairsPromptBuilder.prompt_builder_name() == "repairs_prompt_builder"
 
 
 def test_prompt_builder_from_ui_name():
     assert prompt_builder_from_ui_name("basic") == SimplePromptBuilder
     assert prompt_builder_from_ui_name("few_shot") == FewShotPromptBuilder
     assert prompt_builder_from_ui_name("many_shot") == MultiShotPromptBuilder
+    assert prompt_builder_from_ui_name("repairs") == RepairsPromptBuilder
 
     with pytest.raises(ValueError, match="Unknown prompt builder: invalid_name"):
         prompt_builder_from_ui_name("invalid_name")
@@ -313,3 +321,18 @@ def test_prompt_builder_from_ui_name():
 def test_example_count():
     assert FewShotPromptBuilder.example_count() == 4
     assert MultiShotPromptBuilder.example_count() == 25
+
+
+def test_repair_multi_shot_prompt_builder(task_with_examples):
+    # Verify the order of examples
+    prompt_builder = RepairsPromptBuilder(task=task_with_examples)
+    prompt = prompt_builder.build_prompt()
+    assert (
+        'Repaired Output Which is Sufficient: {"joke": "Why did the cow cross the road? To get to the udder side!"}'
+        in prompt
+    )
+    assert "Instructions On How to Improve the Initial Output: Fix the joke" in prompt
+    assert (
+        'Initial Output Which Was Insufficient: {"joke": "Moo I am a cow joke."}'
+        in prompt
+    )
kiln_ai/adapters/test_structured_output.py
CHANGED
@@ -3,8 +3,9 @@ from typing import Dict
 
 import jsonschema
 import jsonschema.exceptions
-import kiln_ai.datamodel as datamodel
 import pytest
+
+import kiln_ai.datamodel as datamodel
 from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter
 from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
 from kiln_ai.adapters.ml_model_list import (
|