kiln-ai 0.5.5 (.tar.gz) → 0.6.0 (.tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic; see the package's registry page for more details.

Files changed (73)
  1. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/PKG-INFO +1 -2
  2. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/adapters/base_adapter.py +24 -35
  3. kiln_ai-0.6.0/kiln_ai/adapters/data_gen/data_gen_prompts.py +73 -0
  4. kiln_ai-0.6.0/kiln_ai/adapters/data_gen/data_gen_task.py +117 -0
  5. kiln_ai-0.6.0/kiln_ai/adapters/data_gen/test_data_gen_task.py +292 -0
  6. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/adapters/langchain_adapters.py +39 -7
  7. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/adapters/ml_model_list.py +55 -1
  8. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/adapters/prompt_builders.py +66 -0
  9. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/adapters/repair/test_repair_task.py +4 -1
  10. kiln_ai-0.6.0/kiln_ai/adapters/test_langchain_adapter.py +124 -0
  11. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/adapters/test_ml_model_list.py +56 -0
  12. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/adapters/test_prompt_adaptors.py +52 -18
  13. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/adapters/test_prompt_builders.py +97 -7
  14. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/adapters/test_saving_adapter_results.py +16 -6
  15. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/adapters/test_structured_output.py +33 -5
  16. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/datamodel/__init__.py +28 -7
  17. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/datamodel/json_schema.py +1 -0
  18. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/datamodel/test_models.py +44 -8
  19. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/utils/config.py +3 -2
  20. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/utils/test_config.py +7 -0
  21. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/pyproject.toml +1 -1
  22. kiln_ai-0.5.5/kiln_ai/adapters/test_langchain_adapter.py +0 -51
  23. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/.gitignore +0 -0
  24. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/.python-version +0 -0
  25. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/LICENSE.txt +0 -0
  26. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/README.md +0 -0
  27. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/core_library_docs/index.html +0 -0
  28. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/core_library_docs/kiln_ai/adapters/base_adapter.html +0 -0
  29. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/core_library_docs/kiln_ai/adapters/langchain_adapters.html +0 -0
  30. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/core_library_docs/kiln_ai/adapters/ml_model_list.html +0 -0
  31. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/core_library_docs/kiln_ai/adapters/prompt_builders.html +0 -0
  32. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/core_library_docs/kiln_ai/adapters/repair/repair_task.html +0 -0
  33. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/core_library_docs/kiln_ai/adapters/repair.html +0 -0
  34. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/core_library_docs/kiln_ai/adapters.html +0 -0
  35. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/core_library_docs/kiln_ai/datamodel/basemodel.html +0 -0
  36. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/core_library_docs/kiln_ai/datamodel/json_schema.html +0 -0
  37. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/core_library_docs/kiln_ai/datamodel.html +0 -0
  38. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/core_library_docs/kiln_ai/utils/config.html +0 -0
  39. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/core_library_docs/kiln_ai/utils/formatting.html +0 -0
  40. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/core_library_docs/kiln_ai/utils.html +0 -0
  41. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/core_library_docs/kiln_ai.html +0 -0
  42. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/core_library_docs/search.js +0 -0
  43. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/kiln_core_docs/index.html +0 -0
  44. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/kiln_core_docs/kiln_ai/adapters/base_adapter.html +0 -0
  45. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/kiln_core_docs/kiln_ai/adapters/langchain_adapters.html +0 -0
  46. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/kiln_core_docs/kiln_ai/adapters/ml_model_list.html +0 -0
  47. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/kiln_core_docs/kiln_ai/adapters/prompt_builders.html +0 -0
  48. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/kiln_core_docs/kiln_ai/adapters/repair/repair_task.html +0 -0
  49. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/kiln_core_docs/kiln_ai/adapters/repair.html +0 -0
  50. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/kiln_core_docs/kiln_ai/adapters.html +0 -0
  51. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/kiln_core_docs/kiln_ai/datamodel/basemodel.html +0 -0
  52. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/kiln_core_docs/kiln_ai/datamodel/json_schema.html +0 -0
  53. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/kiln_core_docs/kiln_ai/datamodel.html +0 -0
  54. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/kiln_core_docs/kiln_ai/utils/config.html +0 -0
  55. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/kiln_core_docs/kiln_ai/utils/formatting.html +0 -0
  56. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/kiln_core_docs/kiln_ai/utils.html +0 -0
  57. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/kiln_core_docs/kiln_ai.html +0 -0
  58. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/docs/kiln_core_docs/search.js +0 -0
  59. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/__init__.py +0 -0
  60. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/adapters/__init__.py +0 -0
  61. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/adapters/repair/__init__.py +0 -0
  62. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/adapters/repair/repair_task.py +0 -0
  63. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/datamodel/basemodel.py +0 -0
  64. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/datamodel/test_basemodel.py +0 -0
  65. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/datamodel/test_datasource.py +0 -0
  66. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/datamodel/test_example_models.py +0 -0
  67. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/datamodel/test_json_schema.py +0 -0
  68. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/datamodel/test_nested_save.py +0 -0
  69. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/datamodel/test_output_rating.py +0 -0
  70. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/utils/__init__.py +0 -0
  71. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/kiln_ai/utils/formatting.py +0 -0
  72. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/setup.cfg +0 -0
  73. {kiln_ai-0.5.5 → kiln_ai-0.6.0}/uv.lock +0 -0
@@ -1,13 +1,12 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: kiln-ai
3
- Version: 0.5.5
3
+ Version: 0.6.0
4
4
  Summary: Kiln AI
5
5
  Project-URL: Homepage, https://getkiln.ai
6
6
  Project-URL: Repository, https://github.com/Kiln-AI/kiln
7
7
  Project-URL: Documentation, https://kiln-ai.github.io/Kiln/kiln_core_docs/kiln_ai.html
8
8
  Project-URL: Issues, https://github.com/Kiln-AI/kiln/issues
9
9
  Author-email: "Steve Cosman, Chesterfield Laboratories Inc" <scosman@users.noreply.github.com>
10
- License-File: LICENSE.txt
11
10
  Classifier: Intended Audience :: Developers
12
11
  Classifier: License :: OSI Approved :: MIT License
13
12
  Classifier: Programming Language :: Python :: 3.10
@@ -24,6 +24,12 @@ class AdapterInfo:
24
24
  prompt_builder_name: str
25
25
 
26
26
 
27
+ @dataclass
28
+ class RunOutput:
29
+ output: Dict | str
30
+ intermediate_outputs: Dict[str, str] | None
31
+
32
+
27
33
  class BaseAdapter(metaclass=ABCMeta):
28
34
  """Base class for AI model adapters that handle task execution.
29
35
 
@@ -36,22 +42,6 @@ class BaseAdapter(metaclass=ABCMeta):
36
42
  kiln_task (Task): The task configuration and metadata
37
43
  output_schema (dict | None): JSON schema for validating structured outputs
38
44
  input_schema (dict | None): JSON schema for validating structured inputs
39
-
40
- Example:
41
- ```python
42
- class CustomAdapter(BaseAdapter):
43
- async def _run(self, input: Dict | str) -> Dict | str:
44
- # Implementation for specific model
45
- pass
46
-
47
- def adapter_info(self) -> AdapterInfo:
48
- return AdapterInfo(
49
- adapter_name="custom",
50
- model_name="model-1",
51
- model_provider="provider",
52
- prompt_builder_name="simple"
53
- )
54
- ```
55
45
  """
56
46
 
57
47
  def __init__(
@@ -85,21 +75,23 @@ class BaseAdapter(metaclass=ABCMeta):
85
75
  validate_schema(input, self.input_schema)
86
76
 
87
77
  # Run
88
- result = await self._run(input)
78
+ run_output = await self._run(input)
89
79
 
90
80
  # validate output
91
81
  if self.output_schema is not None:
92
- if not isinstance(result, dict):
93
- raise RuntimeError(f"structured response is not a dict: {result}")
94
- validate_schema(result, self.output_schema)
82
+ if not isinstance(run_output.output, dict):
83
+ raise RuntimeError(
84
+ f"structured response is not a dict: {run_output.output}"
85
+ )
86
+ validate_schema(run_output.output, self.output_schema)
95
87
  else:
96
- if not isinstance(result, str):
88
+ if not isinstance(run_output.output, str):
97
89
  raise RuntimeError(
98
- f"response is not a string for non-structured task: {result}"
90
+ f"response is not a string for non-structured task: {run_output.output}"
99
91
  )
100
92
 
101
93
  # Generate the run and output
102
- run = self.generate_run(input, input_source, result)
94
+ run = self.generate_run(input, input_source, run_output)
103
95
 
104
96
  # Save the run if configured to do so, and we have a path to save to
105
97
  if Config.shared().autosave_runs and self.kiln_task.path is not None:
@@ -118,27 +110,23 @@ class BaseAdapter(metaclass=ABCMeta):
118
110
  pass
119
111
 
120
112
  @abstractmethod
121
- async def _run(self, input: Dict | str) -> Dict | str:
113
+ async def _run(self, input: Dict | str) -> RunOutput:
122
114
  pass
123
115
 
124
116
  def build_prompt(self) -> str:
125
- prompt = self.prompt_builder.build_prompt()
126
- adapter_instructions = self.adapter_specific_instructions()
127
- if adapter_instructions is not None:
128
- prompt += f"# Format Instructions\n\n{adapter_instructions}\n\n"
129
- return prompt
130
-
131
- # override for adapter specific instructions (e.g. tool calling, json format, etc)
132
- def adapter_specific_instructions(self) -> str | None:
133
- return None
117
+ return self.prompt_builder.build_prompt()
134
118
 
135
119
  # create a run and task output
136
120
  def generate_run(
137
- self, input: Dict | str, input_source: DataSource | None, output: Dict | str
121
+ self, input: Dict | str, input_source: DataSource | None, run_output: RunOutput
138
122
  ) -> TaskRun:
139
123
  # Convert input and output to JSON strings if they are dictionaries
140
124
  input_str = json.dumps(input) if isinstance(input, dict) else input
141
- output_str = json.dumps(output) if isinstance(output, dict) else output
125
+ output_str = (
126
+ json.dumps(run_output.output)
127
+ if isinstance(run_output.output, dict)
128
+ else run_output.output
129
+ )
142
130
 
143
131
  # If no input source is provided, use the human data source
144
132
  if input_source is None:
@@ -159,6 +147,7 @@ class BaseAdapter(metaclass=ABCMeta):
159
147
  properties=self._properties_for_task_output(),
160
148
  ),
161
149
  ),
150
+ intermediate_outputs=run_output.intermediate_outputs,
162
151
  )
163
152
 
164
153
  exclude_fields = {
@@ -0,0 +1,73 @@
1
+ # The contents of this file are adapted from the promptwrite library (https://github.com/StacklokLabs/promptwright),
2
+ # which was adapted from the pluto library (https://github.com/redotvideo/pluto).
3
+ # These libraries are licensed under the Apache License 2.0. Any modifications
4
+ # are licensed under the kiln AI Core license (MIT at time of writing). See /libs/core/LICENSE.txt for details.
5
+
6
+
7
+ TREE_GENERATION_PROMPT = """I want to train a large language model and I am using another, bigger large language model to generate training data for this. However, if we always ask the bigger model to generate training data with the same prompt, it will end up generating very repetitive training samples. Therefore, we will slightly modify our prompt for each sampling procedure according to some aspects. For instance, when asking the model to generate news articles, we could modify the prompt to let the model tell news articles about particular topics, such as business or politics. To further generate training data, we will do this recursively, and generate submodifications to the prompt. For instance, within the domain of business, we could adapt the prompt to generate news about the stock market or business scandals, and within politics, we could ask the model to generate articles for subtopics like elections or climate policy. We do this recursively, and therefore, we get a tree-like structure of topics.
8
+ Your job is the following: I will give you a path of nodes down the topic tree - you should then come up with a list of new subtopics for this given node and return it as a python list. Here are a few examples of what your outputs should look like, related to the news example I just gave you:
9
+
10
+ Example 1:
11
+ node path: "News Topics" -> "Sports" -> "Football"
12
+ desired number of subtopics: 5
13
+ subtopics: ["College Football", "Football Stadiums", "Health Consequences Football", "Seattle Seahawks", "Football Sponsorships"]
14
+
15
+ Example 2:
16
+ node path: "News Topics" -> "Entertainment" -> "Movies" -> "Star Portraits"
17
+ desired number of subtopics: 8
18
+ subtopics: ["Tom Hanks", "Meryl Streep", "Leonardo DiCaprio", "Jennifer Lawrence", "Denzel Washington", "Charlize Theron", "Robert Downey Jr.", "Emma Stone"]
19
+
20
+
21
+ Here are three new examples, this time for generating smalltalk topics for a friendly chat assistant:
22
+
23
+ Example 1:
24
+ node path: "Small Talk Topics"
25
+ desired number of subtopics: 7
26
+ subtopics: ["Weather", "Weekend Plans", "Hobbies", "Family", "Books", "Food", "Music"]
27
+
28
+ Example 2:
29
+ node path: "Small Talk Topics" -> "Family"
30
+ desired number of subtopics: 5
31
+ subtopics: ["Parents", "Grandparents", "Siblings", "Family Traditions", "Family Vacations"]
32
+
33
+ Example 3:
34
+ node path: "Small Talk Topics" -> "Hobbies" -> "Cooking"
35
+ desired number of subtopics: 6
36
+ subtopics: ["Recipes", "Asian Food", "Favourite Dishes", "Cookbooks", "Kitchen Gadgets", "Vegan Cooking"]
37
+
38
+ The user message will contain the following:
39
+ - The system prompt for the model we want to train as system_prompt.
40
+ - The node path as node_path. It will be formated as a list of strings from most general to most specific. For example, the node_path for Example 3 above would be ["Small Talk Topics", "Hobbies", "Cooking"]. If empty, the node path is the root node.
41
+ - The desired number of subtopics for this node as num_subtopics. Return exactly this number of subtopics.
42
+ - Optionally, it may contain human_guidance, which is a string that contains additional instructions for how to generate the subtopics.
43
+ - Optionally, it may contain existing_topics, which is a list of subtopics that already exist at this node. You should not generate subtopics that are in this list.
44
+
45
+
46
+ When generating subtopics, remain somewhat vague. Things can only be tangentially related and they don't have to be interpreted in a single way. Importantly, make sure that the subtopics fit the system prompt.
47
+ """
48
+
49
+
50
+ SAMPLE_GENERATION_PROMPT = """I want to train a large language model and you should help me generate training data for it.
51
+
52
+ Your job is to generate a list of potential inputs to the provided system prompt. They should be diverse and relevant to the system prompt, and the topic if provided.
53
+
54
+ In the user message we'll provide the following:
55
+ - The system prompt as system_prompt
56
+ - A potential topic to generate samples for. This will be a list of strings from most general to most specific. For example, the topic ["Small Talk Topics", "Hobbies", "Cooking"] would represent the topic "Cooking" in the "Hobbies" category of "Small Talk Topics". The list may be empty, in which case you should generate samples using the system prompt alone.
57
+ - The number of samples to generate as num_samples. If greater than 1, generate a range of samples that are diverse and relevant to the system prompt, and the topic if provided.
58
+ - The user message may optionally contain human_guidance, which is a string that contains additional instructions for how to generate the samples.
59
+
60
+ The output must be formatted:
61
+ - in the provided structured format, as an object with a single property "generated_samples" that maps to a list of generated samples that would be inputs to the provided system prompt.
62
+ - With the correct number of samples (num_samples).
63
+ - Do not include any other text or break the schema in any way.
64
+
65
+ Example inputs:
66
+ - system_prompt: "You are an assistant that classifies the tone of a tweet. You should output one of the following labels: 'positive', 'negative', 'neutral'."
67
+ - topic: ["Technology", "New iPhone Event"]
68
+ - num_samples: 2
69
+ Example output: {"generated_samples": ["New iPhone looks amazing! I need that camera.", "Another boring event from Apple.", "New iPhone looks interesting, but I'm waiting for reviews."]}
70
+
71
+
72
+ Note how the output of this task is data to input to the system prompt, not the expected output of the system prompt.
73
+ """
@@ -0,0 +1,117 @@
1
+ import json
2
+
3
+ from kiln_ai.adapters.prompt_builders import SimplePromptBuilder
4
+ from kiln_ai.datamodel import Project, Task
5
+ from pydantic import BaseModel
6
+
7
+ from .data_gen_prompts import (
8
+ SAMPLE_GENERATION_PROMPT,
9
+ TREE_GENERATION_PROMPT,
10
+ )
11
+
12
+
13
+ class DataGenCategoriesTaskInput(BaseModel):
14
+ node_path: list[str]
15
+ system_prompt: str
16
+ num_subtopics: int
17
+ human_guidance: str | None = None
18
+ existing_topics: list[str] | None = None
19
+
20
+ @classmethod
21
+ def from_task(
22
+ cls,
23
+ task: Task,
24
+ node_path: list[str] = [],
25
+ num_subtopics: int = 6,
26
+ human_guidance: str | None = None,
27
+ existing_topics: list[str] | None = None,
28
+ ) -> "DataGenCategoriesTaskInput":
29
+ prompt_builder = SimplePromptBuilder(task=task)
30
+ return cls(
31
+ node_path=node_path,
32
+ num_subtopics=num_subtopics,
33
+ human_guidance=human_guidance,
34
+ existing_topics=existing_topics,
35
+ system_prompt=prompt_builder.build_prompt(),
36
+ )
37
+
38
+
39
+ class DataGenCategoriesTaskOutput(BaseModel):
40
+ subtopics: list[str]
41
+
42
+
43
+ class DataGenCategoriesTask(Task, parent_of={}):
44
+ def __init__(self):
45
+ # Keep the typechecker happy. TODO: shouldn't need this or parent_of above.
46
+ tmp_project = Project(name="DataGen")
47
+ super().__init__(
48
+ name="DataGen",
49
+ parent=tmp_project,
50
+ description="A task which generates synthetic data categories, which in turn are used to generate training data for a model to learn from.",
51
+ instruction=TREE_GENERATION_PROMPT,
52
+ input_json_schema=json.dumps(
53
+ DataGenCategoriesTaskInput.model_json_schema()
54
+ ),
55
+ output_json_schema=json.dumps(
56
+ DataGenCategoriesTaskOutput.model_json_schema()
57
+ ),
58
+ )
59
+
60
+
61
+ class DataGenSampleTaskInput(BaseModel):
62
+ topic: list[str]
63
+ system_prompt: str
64
+ num_samples: int
65
+ human_guidance: str | None = None
66
+
67
+ @classmethod
68
+ def from_task(
69
+ cls,
70
+ task: Task,
71
+ topic: list[str] = [],
72
+ num_samples: int = 8,
73
+ human_guidance: str | None = None,
74
+ ) -> "DataGenSampleTaskInput":
75
+ prompt_builder = SimplePromptBuilder(task=task)
76
+ return cls(
77
+ topic=topic,
78
+ num_samples=num_samples,
79
+ human_guidance=human_guidance,
80
+ system_prompt=prompt_builder.build_prompt(),
81
+ )
82
+
83
+
84
+ def list_json_schema_for_task(task: Task) -> str:
85
+ if task.input_json_schema:
86
+ items_schema = json.loads(task.input_json_schema)
87
+ else:
88
+ items_schema = {"type": "string"}
89
+
90
+ list_schema = {
91
+ "type": "array",
92
+ "items": items_schema,
93
+ }
94
+
95
+ top_level_schema = {
96
+ "type": "object",
97
+ "properties": {
98
+ "generated_samples": list_schema,
99
+ },
100
+ "required": ["generated_samples"],
101
+ }
102
+
103
+ return json.dumps(top_level_schema)
104
+
105
+
106
+ class DataGenSampleTask(Task, parent_of={}):
107
+ def __init__(self, target_task: Task, num_samples: int = 8):
108
+ # Keep the typechecker happy. TODO: shouldn't need this or parent_of above.
109
+ tmp_project = Project(name="DataGenSample")
110
+ super().__init__(
111
+ name="DataGenSample",
112
+ parent=tmp_project,
113
+ description="A task which generates synthetic data samples for a given topic (and optional subtopic).",
114
+ instruction=SAMPLE_GENERATION_PROMPT,
115
+ input_json_schema=json.dumps(DataGenSampleTaskInput.model_json_schema()),
116
+ output_json_schema=list_json_schema_for_task(target_task),
117
+ )
@@ -0,0 +1,292 @@
1
+ import json
2
+
3
+ import pytest
4
+ from kiln_ai.adapters.data_gen.data_gen_task import (
5
+ DataGenCategoriesTask,
6
+ DataGenCategoriesTaskInput,
7
+ DataGenCategoriesTaskOutput,
8
+ DataGenSampleTask,
9
+ DataGenSampleTaskInput,
10
+ list_json_schema_for_task,
11
+ )
12
+ from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
13
+ from kiln_ai.adapters.ml_model_list import get_model_and_provider
14
+ from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
15
+ from kiln_ai.datamodel import Project, Task
16
+
17
+
18
+ @pytest.fixture
19
+ def base_task():
20
+ project = Project(name="TestProject")
21
+ return Task(
22
+ name="Cowboy Speaker",
23
+ parent=project,
24
+ description="Reply like a cowboy",
25
+ instruction="Reply like a cowboy",
26
+ requirements=[],
27
+ )
28
+
29
+
30
+ def test_data_gen_categories_task_input_initialization(base_task):
31
+ # Arrange
32
+ node_path = ["root", "branch", "leaf"]
33
+ num_subtopics = 4
34
+ human_guidance = "Test guidance"
35
+
36
+ # Act
37
+ input_model = DataGenCategoriesTaskInput.from_task(
38
+ task=base_task,
39
+ node_path=node_path,
40
+ num_subtopics=num_subtopics,
41
+ human_guidance=human_guidance,
42
+ )
43
+
44
+ # Assert
45
+ assert input_model.node_path == node_path
46
+ assert input_model.num_subtopics == num_subtopics
47
+ assert input_model.human_guidance == human_guidance
48
+ assert isinstance(input_model.system_prompt, str)
49
+ assert "Reply like a cowboy" in input_model.system_prompt
50
+
51
+
52
+ def test_data_gen_categories_task_input_default_values(base_task):
53
+ # Act
54
+ input_model = DataGenCategoriesTaskInput.from_task(task=base_task)
55
+
56
+ # Assert
57
+ assert input_model.num_subtopics == 6
58
+ assert input_model.human_guidance is None
59
+ assert input_model.node_path == []
60
+
61
+
62
+ def test_data_gen_categories_task_initialization():
63
+ # Act
64
+ task = DataGenCategoriesTask()
65
+
66
+ # Assert
67
+ assert task.name == "DataGen"
68
+ assert isinstance(task.parent, Project)
69
+ assert task.description is not None
70
+ assert task.instruction is not None
71
+ assert isinstance(task.input_json_schema, str)
72
+ assert isinstance(task.output_json_schema, str)
73
+
74
+
75
+ def test_data_gen_categories_task_schemas():
76
+ # Act
77
+ task = DataGenCategoriesTask()
78
+
79
+ # Assert
80
+ input_schema = json.loads(task.input_json_schema)
81
+ output_schema = json.loads(task.output_json_schema)
82
+
83
+ assert isinstance(input_schema, dict)
84
+ assert isinstance(output_schema, dict)
85
+ assert output_schema["type"] == "object"
86
+ assert output_schema["properties"]["subtopics"]["type"] == "array"
87
+ assert input_schema["properties"]["node_path"]["type"] == "array"
88
+ assert input_schema["properties"]["num_subtopics"]["type"] == "integer"
89
+ assert set(input_schema["required"]) == {
90
+ "node_path",
91
+ "num_subtopics",
92
+ "system_prompt",
93
+ }
94
+
95
+
96
+ @pytest.mark.paid
97
+ @pytest.mark.ollama
98
+ @pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
99
+ async def test_data_gen_all_models_providers(
100
+ tmp_path, model_name, provider_name, base_task
101
+ ):
102
+ _, provider = get_model_and_provider(model_name, provider_name)
103
+ if not provider.supports_data_gen:
104
+ # pass if the model doesn't support data gen (testing the support flag is part of this)
105
+ return
106
+
107
+ data_gen_task = DataGenCategoriesTask()
108
+ data_gen_input = DataGenCategoriesTaskInput.from_task(base_task, num_subtopics=6)
109
+
110
+ adapter = LangChainPromptAdapter(
111
+ data_gen_task,
112
+ model_name=model_name,
113
+ provider=provider_name,
114
+ )
115
+
116
+ input_dict = data_gen_input.model_dump()
117
+ run = await adapter.invoke(input_dict)
118
+ parsed_output = DataGenCategoriesTaskOutput.model_validate_json(run.output.output)
119
+ assert len(parsed_output.subtopics) == 6
120
+ for subtopic in parsed_output.subtopics:
121
+ assert isinstance(subtopic, str)
122
+
123
+
124
+ def test_data_gen_sample_task_input_initialization(base_task):
125
+ # Arrange
126
+ topic = ["cowboys", "hats"]
127
+ num_samples = 4
128
+ human_guidance = "Test guidance"
129
+
130
+ # Act
131
+ input_model = DataGenSampleTaskInput.from_task(
132
+ task=base_task,
133
+ topic=topic,
134
+ num_samples=num_samples,
135
+ human_guidance=human_guidance,
136
+ )
137
+
138
+ # Assert
139
+ assert input_model.topic == topic
140
+ assert input_model.num_samples == num_samples
141
+ assert input_model.human_guidance == human_guidance
142
+ assert isinstance(input_model.system_prompt, str)
143
+ assert "Reply like a cowboy" in input_model.system_prompt
144
+
145
+
146
+ def test_data_gen_sample_task_input_default_values(base_task):
147
+ # Act
148
+ input_model = DataGenSampleTaskInput.from_task(task=base_task)
149
+
150
+ # Assert
151
+ assert input_model.num_samples == 8
152
+ assert input_model.human_guidance is None
153
+ assert input_model.topic == []
154
+
155
+
156
+ def test_data_gen_sample_task_initialization(base_task):
157
+ # Act
158
+ task = DataGenSampleTask(target_task=base_task)
159
+
160
+ # Assert
161
+ assert task.name == "DataGenSample"
162
+ assert isinstance(task.parent, Project)
163
+ assert task.description is not None
164
+ assert task.instruction is not None
165
+
166
+ input_schema = json.loads(task.input_json_schema)
167
+ output_schema = json.loads(task.output_json_schema)
168
+
169
+ assert isinstance(input_schema, dict)
170
+ assert isinstance(output_schema, dict)
171
+ assert output_schema["type"] == "object"
172
+ assert output_schema["properties"]["generated_samples"]["type"] == "array"
173
+ assert input_schema["properties"]["topic"]["type"] == "array"
174
+ assert input_schema["properties"]["num_samples"]["type"] == "integer"
175
+ assert set(input_schema["required"]) == {
176
+ "topic",
177
+ "num_samples",
178
+ "system_prompt",
179
+ }
180
+
181
+
182
+ def test_list_json_schema_for_task_with_output_schema(base_task):
183
+ # Arrange
184
+ base_task.input_json_schema = json.dumps(
185
+ {
186
+ "type": "object",
187
+ "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
188
+ }
189
+ )
190
+
191
+ # Act
192
+ schema = list_json_schema_for_task(base_task)
193
+ parsed_schema = json.loads(schema)
194
+
195
+ # Assert
196
+ assert parsed_schema["type"] == "object"
197
+ generated_samples_schema = parsed_schema["properties"]["generated_samples"]
198
+ assert generated_samples_schema["type"] == "array"
199
+ assert generated_samples_schema["items"]["type"] == "object"
200
+ assert generated_samples_schema["items"]["properties"]["name"]["type"] == "string"
201
+ assert generated_samples_schema["items"]["properties"]["age"]["type"] == "integer"
202
+
203
+
204
+ def test_list_json_schema_for_task_without_output_schema(base_task):
205
+ # Arrange
206
+ base_task.output_json_schema = None
207
+
208
+ # Act
209
+ schema = list_json_schema_for_task(base_task)
210
+ parsed_schema = json.loads(schema)
211
+
212
+ # Assert
213
+ assert parsed_schema["type"] == "object"
214
+ assert parsed_schema["properties"]["generated_samples"]["type"] == "array"
215
+ assert parsed_schema["properties"]["generated_samples"]["items"]["type"] == "string"
216
+
217
+
218
+ @pytest.mark.paid
219
+ @pytest.mark.ollama
220
+ @pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
221
+ async def test_data_gen_sample_all_models_providers(
222
+ tmp_path, model_name, provider_name, base_task
223
+ ):
224
+ _, provider = get_model_and_provider(model_name, provider_name)
225
+ if not provider.supports_data_gen:
226
+ # pass if the model doesn't support data gen (testing the support flag is part of this)
227
+ return
228
+
229
+ data_gen_task = DataGenSampleTask(target_task=base_task)
230
+ data_gen_input = DataGenSampleTaskInput.from_task(
231
+ base_task, topic=["riding horses"], num_samples=4
232
+ )
233
+
234
+ adapter = LangChainPromptAdapter(
235
+ data_gen_task,
236
+ model_name=model_name,
237
+ provider=provider_name,
238
+ )
239
+
240
+ input_dict = data_gen_input.model_dump()
241
+ run = await adapter.invoke(input_dict)
242
+ parsed_output = json.loads(run.output.output)
243
+ samples = parsed_output["generated_samples"]
244
+ assert len(samples) == 4
245
+ for sample in samples:
246
+ assert isinstance(sample, str)
247
+
248
+
249
+ @pytest.mark.paid
250
+ @pytest.mark.ollama
251
+ @pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
252
+ async def test_data_gen_sample_all_models_providers_with_structured_output(
253
+ tmp_path, model_name, provider_name, base_task
254
+ ):
255
+ base_task.output_json_schema = json.dumps(
256
+ {
257
+ "type": "object",
258
+ "properties": {
259
+ "opening": {"type": "string"},
260
+ "closing": {"type": "string"},
261
+ },
262
+ "required": ["opening", "closing"],
263
+ }
264
+ )
265
+
266
+ _, provider = get_model_and_provider(model_name, provider_name)
267
+ if not provider.supports_data_gen:
268
+ # pass if the model doesn't support data gen (testing the support flag is part of this)
269
+ return
270
+
271
+ data_gen_task = DataGenSampleTask(target_task=base_task)
272
+ data_gen_input = DataGenSampleTaskInput.from_task(
273
+ base_task, topic=["riding horses"], num_samples=4
274
+ )
275
+
276
+ adapter = LangChainPromptAdapter(
277
+ data_gen_task,
278
+ model_name=model_name,
279
+ provider=provider_name,
280
+ )
281
+
282
+ input_dict = data_gen_input.model_dump()
283
+ run = await adapter.invoke(input_dict)
284
+ parsed_output = json.loads(run.output.output)
285
+ samples = parsed_output["generated_samples"]
286
+ assert len(samples) == 4
287
+ for sample in samples:
288
+ assert isinstance(sample, dict)
289
+ assert "opening" in sample
290
+ assert "closing" in sample
291
+ assert isinstance(sample["opening"], str)
292
+ assert isinstance(sample["closing"], str)