kiln-ai 0.5.4__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- kiln_ai/adapters/base_adapter.py +24 -35
- kiln_ai/adapters/data_gen/data_gen_prompts.py +73 -0
- kiln_ai/adapters/data_gen/data_gen_task.py +117 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +292 -0
- kiln_ai/adapters/langchain_adapters.py +39 -7
- kiln_ai/adapters/ml_model_list.py +68 -1
- kiln_ai/adapters/prompt_builders.py +66 -0
- kiln_ai/adapters/repair/test_repair_task.py +4 -1
- kiln_ai/adapters/test_langchain_adapter.py +73 -0
- kiln_ai/adapters/test_ml_model_list.py +56 -0
- kiln_ai/adapters/test_prompt_adaptors.py +54 -18
- kiln_ai/adapters/test_prompt_builders.py +97 -7
- kiln_ai/adapters/test_saving_adapter_results.py +16 -6
- kiln_ai/adapters/test_structured_output.py +33 -5
- kiln_ai/datamodel/__init__.py +28 -7
- kiln_ai/datamodel/json_schema.py +1 -0
- kiln_ai/datamodel/test_models.py +44 -8
- kiln_ai/utils/config.py +3 -2
- kiln_ai/utils/test_config.py +7 -0
- {kiln_ai-0.5.4.dist-info → kiln_ai-0.6.0.dist-info}/METADATA +41 -7
- kiln_ai-0.6.0.dist-info/RECORD +36 -0
- {kiln_ai-0.5.4.dist-info → kiln_ai-0.6.0.dist-info}/WHEEL +1 -1
- kiln_ai-0.5.4.dist-info/RECORD +0 -33
- {kiln_ai-0.5.4.dist-info → kiln_ai-0.6.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/base_adapter.py
CHANGED
@@ -24,6 +24,12 @@ class AdapterInfo:
     prompt_builder_name: str
 
 
+@dataclass
+class RunOutput:
+    output: Dict | str
+    intermediate_outputs: Dict[str, str] | None
+
+
 class BaseAdapter(metaclass=ABCMeta):
     """Base class for AI model adapters that handle task execution.
 
@@ -36,22 +42,6 @@ class BaseAdapter(metaclass=ABCMeta):
         kiln_task (Task): The task configuration and metadata
         output_schema (dict | None): JSON schema for validating structured outputs
         input_schema (dict | None): JSON schema for validating structured inputs
-
-    Example:
-    ```python
-    class CustomAdapter(BaseAdapter):
-        async def _run(self, input: Dict | str) -> Dict | str:
-            # Implementation for specific model
-            pass
-
-        def adapter_info(self) -> AdapterInfo:
-            return AdapterInfo(
-                adapter_name="custom",
-                model_name="model-1",
-                model_provider="provider",
-                prompt_builder_name="simple"
-            )
-    ```
     """
 
     def __init__(
@@ -85,21 +75,23 @@ class BaseAdapter(metaclass=ABCMeta):
         validate_schema(input, self.input_schema)
 
         # Run
-
+        run_output = await self._run(input)
 
         # validate output
         if self.output_schema is not None:
-            if not isinstance(
-                raise RuntimeError(
-
+            if not isinstance(run_output.output, dict):
+                raise RuntimeError(
+                    f"structured response is not a dict: {run_output.output}"
+                )
+            validate_schema(run_output.output, self.output_schema)
         else:
-            if not isinstance(
+            if not isinstance(run_output.output, str):
                 raise RuntimeError(
-                    f"response is not a string for non-structured task: {
+                    f"response is not a string for non-structured task: {run_output.output}"
                 )
 
         # Generate the run and output
-        run = self.generate_run(input, input_source,
+        run = self.generate_run(input, input_source, run_output)
 
         # Save the run if configured to do so, and we have a path to save to
         if Config.shared().autosave_runs and self.kiln_task.path is not None:
@@ -118,27 +110,23 @@ class BaseAdapter(metaclass=ABCMeta):
         pass
 
     @abstractmethod
-    async def _run(self, input: Dict | str) ->
+    async def _run(self, input: Dict | str) -> RunOutput:
         pass
 
     def build_prompt(self) -> str:
-
-        adapter_instructions = self.adapter_specific_instructions()
-        if adapter_instructions is not None:
-            prompt += f"# Format Instructions\n\n{adapter_instructions}\n\n"
-        return prompt
-
-    # override for adapter specific instructions (e.g. tool calling, json format, etc)
-    def adapter_specific_instructions(self) -> str | None:
-        return None
+        return self.prompt_builder.build_prompt()
 
     # create a run and task output
     def generate_run(
-        self, input: Dict | str, input_source: DataSource | None,
+        self, input: Dict | str, input_source: DataSource | None, run_output: RunOutput
     ) -> TaskRun:
         # Convert input and output to JSON strings if they are dictionaries
         input_str = json.dumps(input) if isinstance(input, dict) else input
-        output_str =
+        output_str = (
+            json.dumps(run_output.output)
+            if isinstance(run_output.output, dict)
+            else run_output.output
+        )
 
         # If no input source is provided, use the human data source
         if input_source is None:
@@ -159,6 +147,7 @@ class BaseAdapter(metaclass=ABCMeta):
                     properties=self._properties_for_task_output(),
                 ),
             ),
+            intermediate_outputs=run_output.intermediate_outputs,
         )
 
         exclude_fields = {
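The docstring example removed above described the old `_run` contract. For orientation, a minimal sketch of a custom adapter under the new contract, assuming only the names shown in this diff (`BaseAdapter`, `RunOutput`, `AdapterInfo`, `_run`, `adapter_info`); `EchoAdapter` and its values are hypothetical and not part of the package:

from typing import Dict

from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter, RunOutput


class EchoAdapter(BaseAdapter):
    """Hypothetical adapter that simply echoes its input back."""

    async def _run(self, input: Dict | str) -> RunOutput:
        # _run now returns a RunOutput rather than a bare Dict | str, so an
        # adapter can attach intermediate outputs (e.g. chain-of-thought text).
        return RunOutput(output=input, intermediate_outputs=None)

    def adapter_info(self) -> AdapterInfo:
        return AdapterInfo(
            adapter_name="echo",
            model_name="model-1",
            model_provider="provider",
            prompt_builder_name="simple",
        )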
kiln_ai/adapters/data_gen/data_gen_prompts.py
ADDED
@@ -0,0 +1,73 @@
+# The contents of this file are adapted from the promptwrite library (https://github.com/StacklokLabs/promptwright),
+# which was adapted from the pluto library (https://github.com/redotvideo/pluto).
+# These libraries are licensed under the Apache License 2.0. Any modifications
+# are licensed under the kiln AI Core license (MIT at time of writing). See /libs/core/LICENSE.txt for details.
+
+
+TREE_GENERATION_PROMPT = """I want to train a large language model and I am using another, bigger large language model to generate training data for this. However, if we always ask the bigger model to generate training data with the same prompt, it will end up generating very repetitive training samples. Therefore, we will slightly modify our prompt for each sampling procedure according to some aspects. For instance, when asking the model to generate news articles, we could modify the prompt to let the model tell news articles about particular topics, such as business or politics. To further generate training data, we will do this recursively, and generate submodifications to the prompt. For instance, within the domain of business, we could adapt the prompt to generate news about the stock market or business scandals, and within politics, we could ask the model to generate articles for subtopics like elections or climate policy. We do this recursively, and therefore, we get a tree-like structure of topics.
+Your job is the following: I will give you a path of nodes down the topic tree - you should then come up with a list of new subtopics for this given node and return it as a python list. Here are a few examples of what your outputs should look like, related to the news example I just gave you:
+
+Example 1:
+node path: "News Topics" -> "Sports" -> "Football"
+desired number of subtopics: 5
+subtopics: ["College Football", "Football Stadiums", "Health Consequences Football", "Seattle Seahawks", "Football Sponsorships"]
+
+Example 2:
+node path: "News Topics" -> "Entertainment" -> "Movies" -> "Star Portraits"
+desired number of subtopics: 8
+subtopics: ["Tom Hanks", "Meryl Streep", "Leonardo DiCaprio", "Jennifer Lawrence", "Denzel Washington", "Charlize Theron", "Robert Downey Jr.", "Emma Stone"]
+
+
+Here are three new examples, this time for generating smalltalk topics for a friendly chat assistant:
+
+Example 1:
+node path: "Small Talk Topics"
+desired number of subtopics: 7
+subtopics: ["Weather", "Weekend Plans", "Hobbies", "Family", "Books", "Food", "Music"]
+
+Example 2:
+node path: "Small Talk Topics" -> "Family"
+desired number of subtopics: 5
+subtopics: ["Parents", "Grandparents", "Siblings", "Family Traditions", "Family Vacations"]
+
+Example 3:
+node path: "Small Talk Topics" -> "Hobbies" -> "Cooking"
+desired number of subtopics: 6
+subtopics: ["Recipes", "Asian Food", "Favourite Dishes", "Cookbooks", "Kitchen Gadgets", "Vegan Cooking"]
+
+The user message will contain the following:
+- The system prompt for the model we want to train as system_prompt.
+- The node path as node_path. It will be formated as a list of strings from most general to most specific. For example, the node_path for Example 3 above would be ["Small Talk Topics", "Hobbies", "Cooking"]. If empty, the node path is the root node.
+- The desired number of subtopics for this node as num_subtopics. Return exactly this number of subtopics.
+- Optionally, it may contain human_guidance, which is a string that contains additional instructions for how to generate the subtopics.
+- Optionally, it may contain existing_topics, which is a list of subtopics that already exist at this node. You should not generate subtopics that are in this list.
+
+
+When generating subtopics, remain somewhat vague. Things can only be tangentially related and they don't have to be interpreted in a single way. Importantly, make sure that the subtopics fit the system prompt.
+"""
+
+
+SAMPLE_GENERATION_PROMPT = """I want to train a large language model and you should help me generate training data for it.
+
+Your job is to generate a list of potential inputs to the provided system prompt. They should be diverse and relevant to the system prompt, and the topic if provided.
+
+In the user message we'll provide the following:
+- The system prompt as system_prompt
+- A potential topic to generate samples for. This will be a list of strings from most general to most specific. For example, the topic ["Small Talk Topics", "Hobbies", "Cooking"] would represent the topic "Cooking" in the "Hobbies" category of "Small Talk Topics". The list may be empty, in which case you should generate samples using the system prompt alone.
+- The number of samples to generate as num_samples. If greater than 1, generate a range of samples that are diverse and relevant to the system prompt, and the topic if provided.
+- The user message may optionally contain human_guidance, which is a string that contains additional instructions for how to generate the samples.
+
+The output must be formatted:
+- in the provided structured format, as an object with a single property "generated_samples" that maps to a list of generated samples that would be inputs to the provided system prompt.
+- With the correct number of samples (num_samples).
+- Do not include any other text or break the schema in any way.
+
+Example inputs:
+- system_prompt: "You are an assistant that classifies the tone of a tweet. You should output one of the following labels: 'positive', 'negative', 'neutral'."
+- topic: ["Technology", "New iPhone Event"]
+- num_samples: 2
+Example output: {"generated_samples": ["New iPhone looks amazing! I need that camera.", "Another boring event from Apple.", "New iPhone looks interesting, but I'm waiting for reviews."]}
+
+
+Note how the output of this task is data to input to the system prompt, not the expected output of the system prompt.
+"""
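To make the message protocol these prompts describe concrete, here is a rough illustration of one tree-generation exchange. The field names match the prompt above and the task input/output models in the next file; the specific values are made up:

# Illustrative only: a user message for TREE_GENERATION_PROMPT and the reply shape it asks for.
user_message = {
    "system_prompt": "You are an assistant that answers cooking questions.",
    "node_path": ["Small Talk Topics", "Hobbies", "Cooking"],
    "num_subtopics": 3,
    "human_guidance": None,
    "existing_topics": ["Recipes"],
}

# Expected structured reply: exactly num_subtopics new subtopics, none repeating existing_topics.
expected_reply = {"subtopics": ["Cookbooks", "Kitchen Gadgets", "Vegan Cooking"]}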
kiln_ai/adapters/data_gen/data_gen_task.py
ADDED
@@ -0,0 +1,117 @@
+import json
+
+from kiln_ai.adapters.prompt_builders import SimplePromptBuilder
+from kiln_ai.datamodel import Project, Task
+from pydantic import BaseModel
+
+from .data_gen_prompts import (
+    SAMPLE_GENERATION_PROMPT,
+    TREE_GENERATION_PROMPT,
+)
+
+
+class DataGenCategoriesTaskInput(BaseModel):
+    node_path: list[str]
+    system_prompt: str
+    num_subtopics: int
+    human_guidance: str | None = None
+    existing_topics: list[str] | None = None
+
+    @classmethod
+    def from_task(
+        cls,
+        task: Task,
+        node_path: list[str] = [],
+        num_subtopics: int = 6,
+        human_guidance: str | None = None,
+        existing_topics: list[str] | None = None,
+    ) -> "DataGenCategoriesTaskInput":
+        prompt_builder = SimplePromptBuilder(task=task)
+        return cls(
+            node_path=node_path,
+            num_subtopics=num_subtopics,
+            human_guidance=human_guidance,
+            existing_topics=existing_topics,
+            system_prompt=prompt_builder.build_prompt(),
+        )
+
+
+class DataGenCategoriesTaskOutput(BaseModel):
+    subtopics: list[str]
+
+
+class DataGenCategoriesTask(Task, parent_of={}):
+    def __init__(self):
+        # Keep the typechecker happy. TODO: shouldn't need this or parent_of above.
+        tmp_project = Project(name="DataGen")
+        super().__init__(
+            name="DataGen",
+            parent=tmp_project,
+            description="A task which generates synthetic data categories, which in turn are used to generate training data for a model to learn from.",
+            instruction=TREE_GENERATION_PROMPT,
+            input_json_schema=json.dumps(
+                DataGenCategoriesTaskInput.model_json_schema()
+            ),
+            output_json_schema=json.dumps(
+                DataGenCategoriesTaskOutput.model_json_schema()
+            ),
+        )
+
+
+class DataGenSampleTaskInput(BaseModel):
+    topic: list[str]
+    system_prompt: str
+    num_samples: int
+    human_guidance: str | None = None
+
+    @classmethod
+    def from_task(
+        cls,
+        task: Task,
+        topic: list[str] = [],
+        num_samples: int = 8,
+        human_guidance: str | None = None,
+    ) -> "DataGenSampleTaskInput":
+        prompt_builder = SimplePromptBuilder(task=task)
+        return cls(
+            topic=topic,
+            num_samples=num_samples,
+            human_guidance=human_guidance,
+            system_prompt=prompt_builder.build_prompt(),
+        )
+
+
+def list_json_schema_for_task(task: Task) -> str:
+    if task.input_json_schema:
+        items_schema = json.loads(task.input_json_schema)
+    else:
+        items_schema = {"type": "string"}
+
+    list_schema = {
+        "type": "array",
+        "items": items_schema,
+    }
+
+    top_level_schema = {
+        "type": "object",
+        "properties": {
+            "generated_samples": list_schema,
+        },
+        "required": ["generated_samples"],
+    }
+
+    return json.dumps(top_level_schema)
+
+
+class DataGenSampleTask(Task, parent_of={}):
+    def __init__(self, target_task: Task, num_samples: int = 8):
+        # Keep the typechecker happy. TODO: shouldn't need this or parent_of above.
+        tmp_project = Project(name="DataGenSample")
+        super().__init__(
+            name="DataGenSample",
+            parent=tmp_project,
+            description="A task which generates synthetic data samples for a given topic (and optional subtopic).",
+            instruction=SAMPLE_GENERATION_PROMPT,
+            input_json_schema=json.dumps(DataGenSampleTaskInput.model_json_schema()),
+            output_json_schema=list_json_schema_for_task(target_task),
+        )
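A minimal usage sketch for the module above, assuming only the classes it defines and the Task/Project constructor arguments used by the tests that follow; the project and task values are made up:

import json

from kiln_ai.adapters.data_gen.data_gen_task import (
    DataGenSampleTask,
    DataGenSampleTaskInput,
    list_json_schema_for_task,
)
from kiln_ai.datamodel import Project, Task

# A plain-text target task, mirroring the fixture used in the tests below.
project = Project(name="DemoProject")
target = Task(
    name="Cowboy Speaker",
    parent=project,
    description="Reply like a cowboy",
    instruction="Reply like a cowboy",
    requirements=[],
)

# With no input_json_schema on the target task, generated samples are plain strings.
print(json.loads(list_json_schema_for_task(target)))
# {'type': 'object', 'properties': {'generated_samples': {'type': 'array',
#  'items': {'type': 'string'}}}, 'required': ['generated_samples']}

gen_task = DataGenSampleTask(target_task=target)
gen_input = DataGenSampleTaskInput.from_task(target, topic=["riding horses"], num_samples=4)
# gen_input.model_dump() is the input an adapter (e.g. LangChainPromptAdapter) would receive.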
kiln_ai/adapters/data_gen/test_data_gen_task.py
ADDED
@@ -0,0 +1,292 @@
+import json
+
+import pytest
+from kiln_ai.adapters.data_gen.data_gen_task import (
+    DataGenCategoriesTask,
+    DataGenCategoriesTaskInput,
+    DataGenCategoriesTaskOutput,
+    DataGenSampleTask,
+    DataGenSampleTaskInput,
+    list_json_schema_for_task,
+)
+from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
+from kiln_ai.adapters.ml_model_list import get_model_and_provider
+from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
+from kiln_ai.datamodel import Project, Task
+
+
+@pytest.fixture
+def base_task():
+    project = Project(name="TestProject")
+    return Task(
+        name="Cowboy Speaker",
+        parent=project,
+        description="Reply like a cowboy",
+        instruction="Reply like a cowboy",
+        requirements=[],
+    )
+
+
+def test_data_gen_categories_task_input_initialization(base_task):
+    # Arrange
+    node_path = ["root", "branch", "leaf"]
+    num_subtopics = 4
+    human_guidance = "Test guidance"
+
+    # Act
+    input_model = DataGenCategoriesTaskInput.from_task(
+        task=base_task,
+        node_path=node_path,
+        num_subtopics=num_subtopics,
+        human_guidance=human_guidance,
+    )
+
+    # Assert
+    assert input_model.node_path == node_path
+    assert input_model.num_subtopics == num_subtopics
+    assert input_model.human_guidance == human_guidance
+    assert isinstance(input_model.system_prompt, str)
+    assert "Reply like a cowboy" in input_model.system_prompt
+
+
+def test_data_gen_categories_task_input_default_values(base_task):
+    # Act
+    input_model = DataGenCategoriesTaskInput.from_task(task=base_task)
+
+    # Assert
+    assert input_model.num_subtopics == 6
+    assert input_model.human_guidance is None
+    assert input_model.node_path == []
+
+
+def test_data_gen_categories_task_initialization():
+    # Act
+    task = DataGenCategoriesTask()
+
+    # Assert
+    assert task.name == "DataGen"
+    assert isinstance(task.parent, Project)
+    assert task.description is not None
+    assert task.instruction is not None
+    assert isinstance(task.input_json_schema, str)
+    assert isinstance(task.output_json_schema, str)
+
+
+def test_data_gen_categories_task_schemas():
+    # Act
+    task = DataGenCategoriesTask()
+
+    # Assert
+    input_schema = json.loads(task.input_json_schema)
+    output_schema = json.loads(task.output_json_schema)
+
+    assert isinstance(input_schema, dict)
+    assert isinstance(output_schema, dict)
+    assert output_schema["type"] == "object"
+    assert output_schema["properties"]["subtopics"]["type"] == "array"
+    assert input_schema["properties"]["node_path"]["type"] == "array"
+    assert input_schema["properties"]["num_subtopics"]["type"] == "integer"
+    assert set(input_schema["required"]) == {
+        "node_path",
+        "num_subtopics",
+        "system_prompt",
+    }
+
+
+@pytest.mark.paid
+@pytest.mark.ollama
+@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+async def test_data_gen_all_models_providers(
+    tmp_path, model_name, provider_name, base_task
+):
+    _, provider = get_model_and_provider(model_name, provider_name)
+    if not provider.supports_data_gen:
+        # pass if the model doesn't support data gen (testing the support flag is part of this)
+        return
+
+    data_gen_task = DataGenCategoriesTask()
+    data_gen_input = DataGenCategoriesTaskInput.from_task(base_task, num_subtopics=6)
+
+    adapter = LangChainPromptAdapter(
+        data_gen_task,
+        model_name=model_name,
+        provider=provider_name,
+    )
+
+    input_dict = data_gen_input.model_dump()
+    run = await adapter.invoke(input_dict)
+    parsed_output = DataGenCategoriesTaskOutput.model_validate_json(run.output.output)
+    assert len(parsed_output.subtopics) == 6
+    for subtopic in parsed_output.subtopics:
+        assert isinstance(subtopic, str)
+
+
+def test_data_gen_sample_task_input_initialization(base_task):
+    # Arrange
+    topic = ["cowboys", "hats"]
+    num_samples = 4
+    human_guidance = "Test guidance"
+
+    # Act
+    input_model = DataGenSampleTaskInput.from_task(
+        task=base_task,
+        topic=topic,
+        num_samples=num_samples,
+        human_guidance=human_guidance,
+    )
+
+    # Assert
+    assert input_model.topic == topic
+    assert input_model.num_samples == num_samples
+    assert input_model.human_guidance == human_guidance
+    assert isinstance(input_model.system_prompt, str)
+    assert "Reply like a cowboy" in input_model.system_prompt
+
+
+def test_data_gen_sample_task_input_default_values(base_task):
+    # Act
+    input_model = DataGenSampleTaskInput.from_task(task=base_task)
+
+    # Assert
+    assert input_model.num_samples == 8
+    assert input_model.human_guidance is None
+    assert input_model.topic == []
+
+
+def test_data_gen_sample_task_initialization(base_task):
+    # Act
+    task = DataGenSampleTask(target_task=base_task)
+
+    # Assert
+    assert task.name == "DataGenSample"
+    assert isinstance(task.parent, Project)
+    assert task.description is not None
+    assert task.instruction is not None
+
+    input_schema = json.loads(task.input_json_schema)
+    output_schema = json.loads(task.output_json_schema)
+
+    assert isinstance(input_schema, dict)
+    assert isinstance(output_schema, dict)
+    assert output_schema["type"] == "object"
+    assert output_schema["properties"]["generated_samples"]["type"] == "array"
+    assert input_schema["properties"]["topic"]["type"] == "array"
+    assert input_schema["properties"]["num_samples"]["type"] == "integer"
+    assert set(input_schema["required"]) == {
+        "topic",
+        "num_samples",
+        "system_prompt",
+    }
+
+
+def test_list_json_schema_for_task_with_output_schema(base_task):
+    # Arrange
+    base_task.input_json_schema = json.dumps(
+        {
+            "type": "object",
+            "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
+        }
+    )
+
+    # Act
+    schema = list_json_schema_for_task(base_task)
+    parsed_schema = json.loads(schema)
+
+    # Assert
+    assert parsed_schema["type"] == "object"
+    generated_samples_schema = parsed_schema["properties"]["generated_samples"]
+    assert generated_samples_schema["type"] == "array"
+    assert generated_samples_schema["items"]["type"] == "object"
+    assert generated_samples_schema["items"]["properties"]["name"]["type"] == "string"
+    assert generated_samples_schema["items"]["properties"]["age"]["type"] == "integer"
+
+
+def test_list_json_schema_for_task_without_output_schema(base_task):
+    # Arrange
+    base_task.output_json_schema = None
+
+    # Act
+    schema = list_json_schema_for_task(base_task)
+    parsed_schema = json.loads(schema)
+
+    # Assert
+    assert parsed_schema["type"] == "object"
+    assert parsed_schema["properties"]["generated_samples"]["type"] == "array"
+    assert parsed_schema["properties"]["generated_samples"]["items"]["type"] == "string"
+
+
+@pytest.mark.paid
+@pytest.mark.ollama
+@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+async def test_data_gen_sample_all_models_providers(
+    tmp_path, model_name, provider_name, base_task
+):
+    _, provider = get_model_and_provider(model_name, provider_name)
+    if not provider.supports_data_gen:
+        # pass if the model doesn't support data gen (testing the support flag is part of this)
+        return
+
+    data_gen_task = DataGenSampleTask(target_task=base_task)
+    data_gen_input = DataGenSampleTaskInput.from_task(
+        base_task, topic=["riding horses"], num_samples=4
+    )
+
+    adapter = LangChainPromptAdapter(
+        data_gen_task,
+        model_name=model_name,
+        provider=provider_name,
+    )
+
+    input_dict = data_gen_input.model_dump()
+    run = await adapter.invoke(input_dict)
+    parsed_output = json.loads(run.output.output)
+    samples = parsed_output["generated_samples"]
+    assert len(samples) == 4
+    for sample in samples:
+        assert isinstance(sample, str)
+
+
+@pytest.mark.paid
+@pytest.mark.ollama
+@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+async def test_data_gen_sample_all_models_providers_with_structured_output(
+    tmp_path, model_name, provider_name, base_task
+):
+    base_task.output_json_schema = json.dumps(
+        {
+            "type": "object",
+            "properties": {
+                "opening": {"type": "string"},
+                "closing": {"type": "string"},
+            },
+            "required": ["opening", "closing"],
+        }
+    )
+
+    _, provider = get_model_and_provider(model_name, provider_name)
+    if not provider.supports_data_gen:
+        # pass if the model doesn't support data gen (testing the support flag is part of this)
+        return
+
+    data_gen_task = DataGenSampleTask(target_task=base_task)
+    data_gen_input = DataGenSampleTaskInput.from_task(
+        base_task, topic=["riding horses"], num_samples=4
+    )
+
+    adapter = LangChainPromptAdapter(
+        data_gen_task,
+        model_name=model_name,
+        provider=provider_name,
+    )
+
+    input_dict = data_gen_input.model_dump()
+    run = await adapter.invoke(input_dict)
+    parsed_output = json.loads(run.output.output)
+    samples = parsed_output["generated_samples"]
+    assert len(samples) == 4
+    for sample in samples:
+        assert isinstance(sample, dict)
+        assert "opening" in sample
+        assert "closing" in sample
+        assert isinstance(sample["opening"], str)
+        assert isinstance(sample["closing"], str)
kiln_ai/adapters/langchain_adapters.py
CHANGED
@@ -2,14 +2,14 @@ from typing import Dict
 
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_core.messages.base import BaseMessage
 from langchain_core.runnables import Runnable
 from pydantic import BaseModel
 
 import kiln_ai.datamodel as datamodel
 
-from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder
+from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder, RunOutput
 from .ml_model_list import langchain_model_from
 
 LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]
@@ -84,15 +84,41 @@ class LangChainPromptAdapter(BaseAdapter):
         )
         return self._model
 
-    async def _run(self, input: Dict | str) ->
+    async def _run(self, input: Dict | str) -> RunOutput:
+        model = await self.model()
+        chain = model
+        intermediate_outputs = {}
+
         prompt = self.build_prompt()
         user_msg = self.prompt_builder.build_user_message(input)
         messages = [
             SystemMessage(content=prompt),
             HumanMessage(content=user_msg),
         ]
-
-
+
+        # COT with structured output
+        cot_prompt = self.prompt_builder.chain_of_thought_prompt()
+        if cot_prompt and self.has_structured_output():
+            # Base model (without structured output) used for COT message
+            base_model = await langchain_model_from(
+                self.model_name, self.model_provider
+            )
+            messages.append(
+                SystemMessage(content=cot_prompt),
+            )
+
+            cot_messages = [*messages]
+            cot_response = base_model.invoke(cot_messages)
+            intermediate_outputs["chain_of_thought"] = cot_response.content
+            messages.append(AIMessage(content=cot_response.content))
+            messages.append(
+                SystemMessage(content="Considering the above, return a final result.")
+            )
+        elif cot_prompt:
+            # for plaintext output, we just add COT instructions. We still only make one call.
+            messages.append(SystemMessage(content=cot_prompt))
+
+        response = chain.invoke(messages)
 
         if self.has_structured_output():
             if (
@@ -102,14 +128,20 @@ class LangChainPromptAdapter(BaseAdapter):
             ):
                 raise RuntimeError(f"structured response not returned: {response}")
             structured_response = response["parsed"]
-            return
+            return RunOutput(
+                output=self._munge_response(structured_response),
+                intermediate_outputs=intermediate_outputs,
+            )
         else:
             if not isinstance(response, BaseMessage):
                 raise RuntimeError(f"response is not a BaseMessage: {response}")
             text_content = response.content
            if not isinstance(text_content, str):
                 raise RuntimeError(f"response is not a string: {text_content}")
-            return
+            return RunOutput(
+                output=text_content,
+                intermediate_outputs=intermediate_outputs,
+            )
 
     def adapter_info(self) -> AdapterInfo:
         return AdapterInfo(
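The chain-of-thought branch above stores the model's reasoning under the "chain_of_thought" key of intermediate_outputs, which base_adapter.py now copies onto the saved TaskRun. A rough sketch of reading it back after a run, assuming an already-constructed adapter whose prompt builder supplies a chain-of-thought prompt and that TaskRun exposes the new intermediate_outputs field shown in the base_adapter.py diff:

from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter


async def show_reasoning(adapter: LangChainPromptAdapter, input_dict: dict) -> None:
    # invoke() returns the run; intermediate_outputs is the field added in this release.
    run = await adapter.invoke(input_dict)
    cot = (run.intermediate_outputs or {}).get("chain_of_thought")
    if cot is not None:
        print("chain of thought:", cot)
    print("final output:", run.output.output)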