kiln-ai 0.5.5__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.

This version of kiln-ai might be problematic.
@@ -0,0 +1,293 @@
+ import json
+
+ import pytest
+
+ from kiln_ai.adapters.data_gen.data_gen_task import (
+ DataGenCategoriesTask,
+ DataGenCategoriesTaskInput,
+ DataGenCategoriesTaskOutput,
+ DataGenSampleTask,
+ DataGenSampleTaskInput,
+ list_json_schema_for_task,
+ )
+ from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
+ from kiln_ai.adapters.ml_model_list import get_model_and_provider
+ from kiln_ai.adapters.test_prompt_adaptors import get_all_models_and_providers
+ from kiln_ai.datamodel import Project, Task
+
+
+ @pytest.fixture
+ def base_task():
+ project = Project(name="TestProject")
+ return Task(
+ name="Cowboy Speaker",
+ parent=project,
+ description="Reply like a cowboy",
+ instruction="Reply like a cowboy",
+ requirements=[],
+ )
+
+
+ def test_data_gen_categories_task_input_initialization(base_task):
+ # Arrange
+ node_path = ["root", "branch", "leaf"]
+ num_subtopics = 4
+ human_guidance = "Test guidance"
+
+ # Act
+ input_model = DataGenCategoriesTaskInput.from_task(
+ task=base_task,
+ node_path=node_path,
+ num_subtopics=num_subtopics,
+ human_guidance=human_guidance,
+ )
+
+ # Assert
+ assert input_model.node_path == node_path
+ assert input_model.num_subtopics == num_subtopics
+ assert input_model.human_guidance == human_guidance
+ assert isinstance(input_model.system_prompt, str)
+ assert "Reply like a cowboy" in input_model.system_prompt
+
+
+ def test_data_gen_categories_task_input_default_values(base_task):
+ # Act
+ input_model = DataGenCategoriesTaskInput.from_task(task=base_task)
+
+ # Assert
+ assert input_model.num_subtopics == 6
+ assert input_model.human_guidance is None
+ assert input_model.node_path == []
+
+
+ def test_data_gen_categories_task_initialization():
+ # Act
+ task = DataGenCategoriesTask()
+
+ # Assert
+ assert task.name == "DataGen"
+ assert isinstance(task.parent, Project)
+ assert task.description is not None
+ assert task.instruction is not None
+ assert isinstance(task.input_json_schema, str)
+ assert isinstance(task.output_json_schema, str)
+
+
+ def test_data_gen_categories_task_schemas():
+ # Act
+ task = DataGenCategoriesTask()
+
+ # Assert
+ input_schema = json.loads(task.input_json_schema)
+ output_schema = json.loads(task.output_json_schema)
+
+ assert isinstance(input_schema, dict)
+ assert isinstance(output_schema, dict)
+ assert output_schema["type"] == "object"
+ assert output_schema["properties"]["subtopics"]["type"] == "array"
+ assert input_schema["properties"]["node_path"]["type"] == "array"
+ assert input_schema["properties"]["num_subtopics"]["type"] == "integer"
+ assert set(input_schema["required"]) == {
+ "node_path",
+ "num_subtopics",
+ "system_prompt",
+ }
+
+
+ @pytest.mark.paid
+ @pytest.mark.ollama
+ @pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+ async def test_data_gen_all_models_providers(
+ tmp_path, model_name, provider_name, base_task
+ ):
+ _, provider = get_model_and_provider(model_name, provider_name)
+ if not provider.supports_data_gen:
+ # pass if the model doesn't support data gen (testing the support flag is part of this)
+ return
+
+ data_gen_task = DataGenCategoriesTask()
+ data_gen_input = DataGenCategoriesTaskInput.from_task(base_task, num_subtopics=6)
+
+ adapter = LangChainPromptAdapter(
+ data_gen_task,
+ model_name=model_name,
+ provider=provider_name,
+ )
+
+ input_dict = data_gen_input.model_dump()
+ run = await adapter.invoke(input_dict)
+ parsed_output = DataGenCategoriesTaskOutput.model_validate_json(run.output.output)
+ assert len(parsed_output.subtopics) == 6
+ for subtopic in parsed_output.subtopics:
+ assert isinstance(subtopic, str)
+
+
+ def test_data_gen_sample_task_input_initialization(base_task):
+ # Arrange
+ topic = ["cowboys", "hats"]
+ num_samples = 4
+ human_guidance = "Test guidance"
+
+ # Act
+ input_model = DataGenSampleTaskInput.from_task(
+ task=base_task,
+ topic=topic,
+ num_samples=num_samples,
+ human_guidance=human_guidance,
+ )
+
+ # Assert
+ assert input_model.topic == topic
+ assert input_model.num_samples == num_samples
+ assert input_model.human_guidance == human_guidance
+ assert isinstance(input_model.system_prompt, str)
+ assert "Reply like a cowboy" in input_model.system_prompt
+
+
+ def test_data_gen_sample_task_input_default_values(base_task):
+ # Act
+ input_model = DataGenSampleTaskInput.from_task(task=base_task)
+
+ # Assert
+ assert input_model.num_samples == 8
+ assert input_model.human_guidance is None
+ assert input_model.topic == []
+
+
+ def test_data_gen_sample_task_initialization(base_task):
+ # Act
+ task = DataGenSampleTask(target_task=base_task)
+
+ # Assert
+ assert task.name == "DataGenSample"
+ assert isinstance(task.parent, Project)
+ assert task.description is not None
+ assert task.instruction is not None
+
+ input_schema = json.loads(task.input_json_schema)
+ output_schema = json.loads(task.output_json_schema)
+
+ assert isinstance(input_schema, dict)
+ assert isinstance(output_schema, dict)
+ assert output_schema["type"] == "object"
+ assert output_schema["properties"]["generated_samples"]["type"] == "array"
+ assert input_schema["properties"]["topic"]["type"] == "array"
+ assert input_schema["properties"]["num_samples"]["type"] == "integer"
+ assert set(input_schema["required"]) == {
+ "topic",
+ "num_samples",
+ "system_prompt",
+ }
+
+
+ def test_list_json_schema_for_task_with_output_schema(base_task):
+ # Arrange
+ base_task.input_json_schema = json.dumps(
+ {
+ "type": "object",
+ "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
+ }
+ )
+
+ # Act
+ schema = list_json_schema_for_task(base_task)
+ parsed_schema = json.loads(schema)
+
+ # Assert
+ assert parsed_schema["type"] == "object"
+ generated_samples_schema = parsed_schema["properties"]["generated_samples"]
+ assert generated_samples_schema["type"] == "array"
+ assert generated_samples_schema["items"]["type"] == "object"
+ assert generated_samples_schema["items"]["properties"]["name"]["type"] == "string"
+ assert generated_samples_schema["items"]["properties"]["age"]["type"] == "integer"
+
+
+ def test_list_json_schema_for_task_without_output_schema(base_task):
+ # Arrange
+ base_task.output_json_schema = None
+
+ # Act
+ schema = list_json_schema_for_task(base_task)
+ parsed_schema = json.loads(schema)
+
+ # Assert
+ assert parsed_schema["type"] == "object"
+ assert parsed_schema["properties"]["generated_samples"]["type"] == "array"
+ assert parsed_schema["properties"]["generated_samples"]["items"]["type"] == "string"
+
+
+ @pytest.mark.paid
+ @pytest.mark.ollama
+ @pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+ async def test_data_gen_sample_all_models_providers(
+ tmp_path, model_name, provider_name, base_task
+ ):
+ _, provider = get_model_and_provider(model_name, provider_name)
+ if not provider.supports_data_gen:
+ # pass if the model doesn't support data gen (testing the support flag is part of this)
+ return
+
+ data_gen_task = DataGenSampleTask(target_task=base_task)
+ data_gen_input = DataGenSampleTaskInput.from_task(
+ base_task, topic=["riding horses"], num_samples=4
+ )
+
+ adapter = LangChainPromptAdapter(
+ data_gen_task,
+ model_name=model_name,
+ provider=provider_name,
+ )
+
+ input_dict = data_gen_input.model_dump()
+ run = await adapter.invoke(input_dict)
+ parsed_output = json.loads(run.output.output)
+ samples = parsed_output["generated_samples"]
+ assert len(samples) == 4
+ for sample in samples:
+ assert isinstance(sample, str)
+
+
+ @pytest.mark.paid
+ @pytest.mark.ollama
+ @pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+ async def test_data_gen_sample_all_models_providers_with_structured_output(
+ tmp_path, model_name, provider_name, base_task
+ ):
+ base_task.output_json_schema = json.dumps(
+ {
+ "type": "object",
+ "properties": {
+ "opening": {"type": "string"},
+ "closing": {"type": "string"},
+ },
+ "required": ["opening", "closing"],
+ }
+ )
+
+ _, provider = get_model_and_provider(model_name, provider_name)
+ if not provider.supports_data_gen:
+ # pass if the model doesn't support data gen (testing the support flag is part of this)
+ return
+
+ data_gen_task = DataGenSampleTask(target_task=base_task)
+ data_gen_input = DataGenSampleTaskInput.from_task(
+ base_task, topic=["riding horses"], num_samples=4
+ )
+
+ adapter = LangChainPromptAdapter(
+ data_gen_task,
+ model_name=model_name,
+ provider=provider_name,
+ )
+
+ input_dict = data_gen_input.model_dump()
+ run = await adapter.invoke(input_dict)
+ parsed_output = json.loads(run.output.output)
+ samples = parsed_output["generated_samples"]
+ assert len(samples) == 4
+ for sample in samples:
+ assert isinstance(sample, dict)
+ assert "opening" in sample
+ assert "closing" in sample
+ assert isinstance(sample["opening"], str)
+ assert isinstance(sample["closing"], str)
@@ -2,14 +2,14 @@ from typing import Dict
 
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import BaseChatModel
- from langchain_core.messages import HumanMessage, SystemMessage
+ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_core.messages.base import BaseMessage
 from langchain_core.runnables import Runnable
 from pydantic import BaseModel
 
 import kiln_ai.datamodel as datamodel
 
- from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder
+ from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder, RunOutput
 from .ml_model_list import langchain_model_from
 
 LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]
@@ -84,15 +84,41 @@ class LangChainPromptAdapter(BaseAdapter):
 )
 return self._model
 
- async def _run(self, input: Dict | str) -> Dict | str:
+ async def _run(self, input: Dict | str) -> RunOutput:
+ model = await self.model()
+ chain = model
+ intermediate_outputs = {}
+
 prompt = self.build_prompt()
 user_msg = self.prompt_builder.build_user_message(input)
 messages = [
 SystemMessage(content=prompt),
 HumanMessage(content=user_msg),
 ]
- model = await self.model()
- response = model.invoke(messages)
+
+ # COT with structured output
+ cot_prompt = self.prompt_builder.chain_of_thought_prompt()
+ if cot_prompt and self.has_structured_output():
+ # Base model (without structured output) used for COT message
+ base_model = await langchain_model_from(
+ self.model_name, self.model_provider
+ )
+ messages.append(
+ SystemMessage(content=cot_prompt),
+ )
+
+ cot_messages = [*messages]
+ cot_response = base_model.invoke(cot_messages)
+ intermediate_outputs["chain_of_thought"] = cot_response.content
+ messages.append(AIMessage(content=cot_response.content))
+ messages.append(
+ SystemMessage(content="Considering the above, return a final result.")
+ )
+ elif cot_prompt:
+ # for plaintext output, we just add COT instructions. We still only make one call.
+ messages.append(SystemMessage(content=cot_prompt))
+
+ response = chain.invoke(messages)
 
 if self.has_structured_output():
 if (
@@ -102,14 +128,20 @@ class LangChainPromptAdapter(BaseAdapter):
 ):
 raise RuntimeError(f"structured response not returned: {response}")
 structured_response = response["parsed"]
- return self._munge_response(structured_response)
+ return RunOutput(
+ output=self._munge_response(structured_response),
+ intermediate_outputs=intermediate_outputs,
+ )
 else:
 if not isinstance(response, BaseMessage):
 raise RuntimeError(f"response is not a BaseMessage: {response}")
 text_content = response.content
 if not isinstance(text_content, str):
 raise RuntimeError(f"response is not a string: {text_content}")
- return text_content
+ return RunOutput(
+ output=text_content,
+ intermediate_outputs=intermediate_outputs,
+ )
 
 def adapter_info(self) -> AdapterInfo:
 return AdapterInfo(
@@ -89,6 +89,7 @@ class KilnModelProvider(BaseModel):
 
 name: ModelProviderName
 supports_structured_output: bool = True
+ supports_data_gen: bool = True
 provider_options: Dict = {}
 
 
@@ -176,6 +177,8 @@ built_in_models: List[KilnModel] = [
 providers=[
 KilnModelProvider(
 name=ModelProviderName.openrouter,
+ supports_structured_output=False, # it should, but doesn't work on openrouter
+ supports_data_gen=False, # doesn't work on openrouter
 provider_options={"model": "google/gemini-pro-1.5"},
 ),
 ],
@@ -188,6 +191,7 @@ built_in_models: List[KilnModel] = [
 providers=[
 KilnModelProvider(
 name=ModelProviderName.openrouter,
+ supports_data_gen=False,
 provider_options={"model": "google/gemini-flash-1.5"},
 ),
 ],
@@ -200,6 +204,8 @@ built_in_models: List[KilnModel] = [
 providers=[
 KilnModelProvider(
 name=ModelProviderName.openrouter,
+ supports_structured_output=False,
+ supports_data_gen=False,
 provider_options={"model": "google/gemini-flash-1.5-8b"},
 ),
 ],
@@ -213,6 +219,7 @@ built_in_models: List[KilnModel] = [
 KilnModelProvider(
 name=ModelProviderName.openrouter,
 supports_structured_output=False,
+ supports_data_gen=False,
 provider_options={"model": "nvidia/llama-3.1-nemotron-70b-instruct"},
 ),
 ],
@@ -230,6 +237,7 @@ built_in_models: List[KilnModel] = [
 KilnModelProvider(
 name=ModelProviderName.amazon_bedrock,
 supports_structured_output=False,
+ supports_data_gen=False,
 provider_options={
 "model": "meta.llama3-1-8b-instruct-v1:0",
 "region_name": "us-west-2", # Llama 3.1 only in west-2
@@ -237,6 +245,7 @@ built_in_models: List[KilnModel] = [
 ),
 KilnModelProvider(
 name=ModelProviderName.ollama,
+ supports_data_gen=False,
 provider_options={
 "model": "llama3.1:8b",
 "model_aliases": ["llama3.1"], # 8b is default
@@ -245,6 +254,7 @@ built_in_models: List[KilnModel] = [
 KilnModelProvider(
 name=ModelProviderName.openrouter,
 supports_structured_output=False,
+ supports_data_gen=False,
 provider_options={"model": "meta-llama/llama-3.1-8b-instruct"},
 ),
 ],
@@ -261,7 +271,9 @@ built_in_models: List[KilnModel] = [
 ),
 KilnModelProvider(
 name=ModelProviderName.amazon_bedrock,
+ # not sure how AWS manages to break this, but it's not working
 supports_structured_output=False,
+ supports_data_gen=False,
 provider_options={
 "model": "meta.llama3-1-70b-instruct-v1:0",
 "region_name": "us-west-2", # Llama 3.1 only in west-2
@@ -285,6 +297,7 @@ built_in_models: List[KilnModel] = [
 providers=[
 KilnModelProvider(
 name=ModelProviderName.amazon_bedrock,
+ supports_data_gen=False,
 provider_options={
 "model": "meta.llama3-1-405b-instruct-v1:0",
 "region_name": "us-west-2", # Llama 3.1 only in west-2
@@ -344,8 +357,15 @@ built_in_models: List[KilnModel] = [
 KilnModelProvider(
 name=ModelProviderName.openrouter,
 supports_structured_output=False,
+ supports_data_gen=False,
 provider_options={"model": "meta-llama/llama-3.2-3b-instruct"},
 ),
+ KilnModelProvider(
+ name=ModelProviderName.ollama,
+ supports_structured_output=False,
+ supports_data_gen=False,
+ provider_options={"model": "llama3.2"},
+ ),
 ],
 ),
 # Llama 3.2 11B
@@ -357,8 +377,15 @@ built_in_models: List[KilnModel] = [
 KilnModelProvider(
 name=ModelProviderName.openrouter,
 supports_structured_output=False,
+ supports_data_gen=False,
 provider_options={"model": "meta-llama/llama-3.2-11b-vision-instruct"},
 ),
+ KilnModelProvider(
+ name=ModelProviderName.ollama,
+ supports_structured_output=False,
+ supports_data_gen=False,
+ provider_options={"model": "llama3.2-vision"},
+ ),
 ],
 ),
 # Llama 3.2 90B
@@ -370,8 +397,15 @@ built_in_models: List[KilnModel] = [
 KilnModelProvider(
 name=ModelProviderName.openrouter,
 supports_structured_output=False,
+ supports_data_gen=False,
 provider_options={"model": "meta-llama/llama-3.2-90b-vision-instruct"},
 ),
+ KilnModelProvider(
+ name=ModelProviderName.ollama,
+ supports_structured_output=False,
+ supports_data_gen=False,
+ provider_options={"model": "llama3.2-vision:90b"},
+ ),
 ],
 ),
 # Phi 3.5
@@ -384,10 +418,13 @@ built_in_models: List[KilnModel] = [
 KilnModelProvider(
 name=ModelProviderName.ollama,
 supports_structured_output=False,
+ supports_data_gen=False,
 provider_options={"model": "phi3.5"},
 ),
 KilnModelProvider(
 name=ModelProviderName.openrouter,
+ supports_structured_output=False,
+ supports_data_gen=False,
 provider_options={"model": "microsoft/phi-3.5-mini-128k-instruct"},
 ),
 ],
@@ -402,6 +439,7 @@ built_in_models: List[KilnModel] = [
 KilnModelProvider(
 name=ModelProviderName.ollama,
 supports_structured_output=False,
+ supports_data_gen=False,
 provider_options={
 "model": "gemma2:2b",
 },
@@ -417,12 +455,14 @@ built_in_models: List[KilnModel] = [
 providers=[
 KilnModelProvider(
 name=ModelProviderName.ollama,
+ supports_data_gen=False,
 provider_options={
 "model": "gemma2:9b",
 },
 ),
 KilnModelProvider(
 name=ModelProviderName.openrouter,
+ supports_data_gen=False,
 provider_options={"model": "google/gemma-2-9b-it"},
 ),
 ],
@@ -436,12 +476,14 @@ built_in_models: List[KilnModel] = [
 providers=[
 KilnModelProvider(
 name=ModelProviderName.ollama,
+ supports_data_gen=False,
 provider_options={
 "model": "gemma2:27b",
 },
 ),
 KilnModelProvider(
 name=ModelProviderName.openrouter,
+ supports_data_gen=False,
 provider_options={"model": "google/gemma-2-27b-it"},
 ),
 ],
@@ -449,6 +491,19 @@ built_in_models: List[KilnModel] = [
 ]
 
 
+ def get_model_and_provider(
+ model_name: str, provider_name: str
+ ) -> tuple[KilnModel | None, KilnModelProvider | None]:
+ model = next(filter(lambda m: m.name == model_name, built_in_models), None)
+ if model is None:
+ return None, None
+ provider = next(filter(lambda p: p.name == provider_name, model.providers), None)
+ # all or nothing
+ if provider is None or model is None:
+ return None, None
+ return model, provider
+
+
 def provider_name_from_id(id: str) -> str:
 """
 Converts a provider ID to its human-readable name.
@@ -687,7 +742,6 @@ def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
 models = tags["models"]
 if isinstance(models, list):
 model_names = [model["model"] for model in models]
- print(f"model_names: {model_names}")
 available_supported_models = [
 model
 for model in model_names
@@ -54,6 +54,28 @@ class BasePromptBuilder(metaclass=ABCMeta):
 
 return f"The input is:\n{input}"
 
+ def chain_of_thought_prompt(self) -> str | None:
+ """Build and return the chain of thought prompt string.
+
+ Returns:
+ str: The constructed chain of thought prompt.
+ """
+ return None
+
+ def build_prompt_for_ui(self) -> str:
+ """Build a prompt for the UI. It includes additional instructions (like chain of thought), even if they are passed to the model in stages.
+
+ Designed for end-user consumption, not for model consumption.
+
+ Returns:
+ str: The constructed prompt string.
+ """
+ base_prompt = self.build_prompt()
+ cot_prompt = self.chain_of_thought_prompt()
+ if cot_prompt:
+ base_prompt += "\n# Thinking Instructions\n\n" + cot_prompt
+ return base_prompt
+
 
 class SimplePromptBuilder(BasePromptBuilder):
 """A basic prompt builder that combines task instruction with requirements."""
@@ -187,11 +209,49 @@ class RepairsPromptBuilder(MultiShotPromptBuilder):
 return prompt_section
 
 
+ def chain_of_thought_prompt(task: Task) -> str | None:
+ """Standard implementation to build and return the chain of thought prompt string.
+
+ Returns:
+ str: The constructed chain of thought prompt.
+ """
+
+ cot_instruction = task.thinking_instruction
+ if not cot_instruction:
+ cot_instruction = "Think step by step, explaining your reasoning."
+
+ return cot_instruction
+
+
+ class SimpleChainOfThoughtPromptBuilder(SimplePromptBuilder):
+ """A prompt builder that includes a chain of thought prompt on top of the simple prompt."""
+
+ def chain_of_thought_prompt(self) -> str | None:
+ return chain_of_thought_prompt(self.task)
+
+
+ class FewShotChainOfThoughtPromptBuilder(FewShotPromptBuilder):
+ """A prompt builder that includes a chain of thought prompt on top of the few shot prompt."""
+
+ def chain_of_thought_prompt(self) -> str | None:
+ return chain_of_thought_prompt(self.task)
+
+
+ class MultiShotChainOfThoughtPromptBuilder(MultiShotPromptBuilder):
+ """A prompt builder that includes a chain of thought prompt on top of the multi shot prompt."""
+
+ def chain_of_thought_prompt(self) -> str | None:
+ return chain_of_thought_prompt(self.task)
+
+
 prompt_builder_registry = {
 "simple_prompt_builder": SimplePromptBuilder,
 "multi_shot_prompt_builder": MultiShotPromptBuilder,
 "few_shot_prompt_builder": FewShotPromptBuilder,
 "repairs_prompt_builder": RepairsPromptBuilder,
+ "simple_chain_of_thought_prompt_builder": SimpleChainOfThoughtPromptBuilder,
+ "few_shot_chain_of_thought_prompt_builder": FewShotChainOfThoughtPromptBuilder,
+ "multi_shot_chain_of_thought_prompt_builder": MultiShotChainOfThoughtPromptBuilder,
 }
 
 
@@ -217,5 +277,11 @@ def prompt_builder_from_ui_name(ui_name: str) -> type[BasePromptBuilder]:
 return MultiShotPromptBuilder
 case "repairs":
 return RepairsPromptBuilder
+ case "simple_chain_of_thought":
+ return SimpleChainOfThoughtPromptBuilder
+ case "few_shot_chain_of_thought":
+ return FewShotChainOfThoughtPromptBuilder
+ case "multi_shot_chain_of_thought":
+ return MultiShotChainOfThoughtPromptBuilder
 case _:
 raise ValueError(f"Unknown prompt builder: {ui_name}")
@@ -5,6 +5,7 @@ from unittest.mock import AsyncMock, patch
 import pytest
 from pydantic import ValidationError
 
+ from kiln_ai.adapters.base_adapter import RunOutput
 from kiln_ai.adapters.langchain_adapters import (
 LangChainPromptAdapter,
 )
@@ -222,7 +223,9 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
 with patch.object(
 LangChainPromptAdapter, "_run", new_callable=AsyncMock
 ) as mock_run:
- mock_run.return_value = mocked_output
+ mock_run.return_value = RunOutput(
+ output=mocked_output, intermediate_outputs=None
+ )
 
 adapter = LangChainPromptAdapter(
 repair_task, model_name="llama_3_1_8b", provider="groq"