kiln-ai 0.5.4__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -69,6 +69,7 @@ class ModelName(str, Enum):
     gemma_2_2b = "gemma_2_2b"
     gemma_2_9b = "gemma_2_9b"
     gemma_2_27b = "gemma_2_27b"
+    claude_3_5_haiku = "claude_3_5_haiku"
     claude_3_5_sonnet = "claude_3_5_sonnet"
     gemini_1_5_flash = "gemini_1_5_flash"
     gemini_1_5_flash_8b = "gemini_1_5_flash_8b"
@@ -88,6 +89,7 @@ class KilnModelProvider(BaseModel):
 
     name: ModelProviderName
     supports_structured_output: bool = True
+    supports_data_gen: bool = True
     provider_options: Dict = {}
 
 
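
The new supports_data_gen flag works like the existing supports_structured_output flag: it records whether a given model/provider pair can be trusted for synthetic data generation. A minimal sketch of how a caller might consult it; the assert_supports_data_gen helper is illustrative, not part of this release:

def assert_supports_data_gen(provider: KilnModelProvider) -> None:
    # Hypothetical guard: refuse to run a synthetic-data task against a
    # model/provider pair the catalog has flagged as unsupported.
    if not provider.supports_data_gen:
        raise ValueError(f"{provider.name} is not marked as supporting data generation")
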
@@ -143,6 +145,18 @@ built_in_models: List[KilnModel] = [
             ),
         ],
     ),
+    # Claude 3.5 Haiku
+    KilnModel(
+        family=ModelFamily.claude,
+        name=ModelName.claude_3_5_haiku,
+        friendly_name="Claude 3.5 Haiku",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "anthropic/claude-3-5-haiku"},
+            ),
+        ],
+    ),
     # Claude 3.5 Sonnet
     KilnModel(
         family=ModelFamily.claude,
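
The Haiku entry is an ordinary KilnModel in built_in_models, so it can be looked up like any other model. A small sketch using only fields visible in this diff:

haiku = next((m for m in built_in_models if m.name == ModelName.claude_3_5_haiku), None)
if haiku is not None:
    # Expect "Claude 3.5 Haiku", served via OpenRouter as "anthropic/claude-3-5-haiku"
    print(haiku.friendly_name, [p.provider_options["model"] for p in haiku.providers])
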
@@ -163,6 +177,8 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
+                supports_structured_output=False,  # it should, but doesn't work on openrouter
+                supports_data_gen=False,  # doesn't work on openrouter
                 provider_options={"model": "google/gemini-pro-1.5"},
             ),
         ],
@@ -175,6 +191,7 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
+                supports_data_gen=False,
                 provider_options={"model": "google/gemini-flash-1.5"},
             ),
         ],
@@ -187,6 +204,8 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
+                supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "google/gemini-flash-1.5-8b"},
             ),
         ],
@@ -200,6 +219,7 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "nvidia/llama-3.1-nemotron-70b-instruct"},
             ),
         ],
@@ -217,6 +237,7 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.amazon_bedrock,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={
                     "model": "meta.llama3-1-8b-instruct-v1:0",
                     "region_name": "us-west-2",  # Llama 3.1 only in west-2
@@ -224,6 +245,7 @@ built_in_models: List[KilnModel] = [
             ),
             KilnModelProvider(
                 name=ModelProviderName.ollama,
+                supports_data_gen=False,
                 provider_options={
                     "model": "llama3.1:8b",
                     "model_aliases": ["llama3.1"],  # 8b is default
@@ -232,6 +254,7 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.1-8b-instruct"},
             ),
         ],
@@ -248,7 +271,9 @@ built_in_models: List[KilnModel] = [
             ),
             KilnModelProvider(
                 name=ModelProviderName.amazon_bedrock,
+                # not sure how AWS manages to break this, but it's not working
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={
                     "model": "meta.llama3-1-70b-instruct-v1:0",
                     "region_name": "us-west-2",  # Llama 3.1 only in west-2
@@ -272,6 +297,7 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.amazon_bedrock,
+                supports_data_gen=False,
                 provider_options={
                     "model": "meta.llama3-1-405b-instruct-v1:0",
                     "region_name": "us-west-2",  # Llama 3.1 only in west-2
@@ -331,8 +357,15 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.2-3b-instruct"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={"model": "llama3.2"},
+            ),
         ],
     ),
     # Llama 3.2 11B
@@ -344,8 +377,15 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.2-11b-vision-instruct"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={"model": "llama3.2-vision"},
+            ),
         ],
     ),
     # Llama 3.2 90B
@@ -357,8 +397,15 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.2-90b-vision-instruct"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={"model": "llama3.2-vision:90b"},
+            ),
         ],
     ),
     # Phi 3.5
@@ -371,10 +418,13 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.ollama,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "phi3.5"},
             ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
+                supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={"model": "microsoft/phi-3.5-mini-128k-instruct"},
             ),
         ],
@@ -389,6 +439,7 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.ollama,
                 supports_structured_output=False,
+                supports_data_gen=False,
                 provider_options={
                     "model": "gemma2:2b",
                 },
@@ -404,12 +455,14 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.ollama,
+                supports_data_gen=False,
                 provider_options={
                     "model": "gemma2:9b",
                 },
             ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
+                supports_data_gen=False,
                 provider_options={"model": "google/gemma-2-9b-it"},
             ),
         ],
@@ -423,12 +476,14 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.ollama,
+                supports_data_gen=False,
                 provider_options={
                     "model": "gemma2:27b",
                 },
             ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
+                supports_data_gen=False,
                 provider_options={"model": "google/gemma-2-27b-it"},
             ),
         ],
@@ -436,6 +491,19 @@ built_in_models: List[KilnModel] = [
 ]
 
 
+def get_model_and_provider(
+    model_name: str, provider_name: str
+) -> tuple[KilnModel | None, KilnModelProvider | None]:
+    model = next(filter(lambda m: m.name == model_name, built_in_models), None)
+    if model is None:
+        return None, None
+    provider = next(filter(lambda p: p.name == provider_name, model.providers), None)
+    # all or nothing
+    if provider is None or model is None:
+        return None, None
+    return model, provider
+
+
 def provider_name_from_id(id: str) -> str:
     """
     Converts a provider ID to its human-readable name.
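
get_model_and_provider is all-or-nothing: it returns (None, None) unless both the model and a matching provider entry exist. A usage sketch based on the signature above, with model and provider ids taken from entries elsewhere in this diff:

model, provider = get_model_and_provider(ModelName.llama_3_1_8b, ModelProviderName.ollama)
if model is None or provider is None:
    raise ValueError("unknown model/provider combination")
# provider_options carries the provider-specific model id, e.g. "llama3.1:8b" for Ollama
print(model.friendly_name, provider.provider_options["model"])
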
@@ -674,7 +742,6 @@ def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
     models = tags["models"]
     if isinstance(models, list):
         model_names = [model["model"] for model in models]
-        print(f"model_names: {model_names}")
         available_supported_models = [
             model
             for model in model_names
@@ -54,6 +54,28 @@ class BasePromptBuilder(metaclass=ABCMeta):
 
         return f"The input is:\n{input}"
 
+    def chain_of_thought_prompt(self) -> str | None:
+        """Build and return the chain of thought prompt string.
+
+        Returns:
+            str: The constructed chain of thought prompt.
+        """
+        return None
+
+    def build_prompt_for_ui(self) -> str:
+        """Build a prompt for the UI. It includes additional instructions (like chain of thought), even if they are passed to the model in stages.
+
+        Designed for end-user consumption, not for model consumption.
+
+        Returns:
+            str: The constructed prompt string.
+        """
+        base_prompt = self.build_prompt()
+        cot_prompt = self.chain_of_thought_prompt()
+        if cot_prompt:
+            base_prompt += "\n# Thinking Instructions\n\n" + cot_prompt
+        return base_prompt
+
 
 class SimplePromptBuilder(BasePromptBuilder):
     """A basic prompt builder that combines task instruction with requirements."""
@@ -187,11 +209,49 @@ class RepairsPromptBuilder(MultiShotPromptBuilder):
         return prompt_section
 
 
+def chain_of_thought_prompt(task: Task) -> str | None:
+    """Standard implementation to build and return the chain of thought prompt string.
+
+    Returns:
+        str: The constructed chain of thought prompt.
+    """
+
+    cot_instruction = task.thinking_instruction
+    if not cot_instruction:
+        cot_instruction = "Think step by step, explaining your reasoning."
+
+    return cot_instruction
+
+
+class SimpleChainOfThoughtPromptBuilder(SimplePromptBuilder):
+    """A prompt builder that includes a chain of thought prompt on top of the simple prompt."""
+
+    def chain_of_thought_prompt(self) -> str | None:
+        return chain_of_thought_prompt(self.task)
+
+
+class FewShotChainOfThoughtPromptBuilder(FewShotPromptBuilder):
+    """A prompt builder that includes a chain of thought prompt on top of the few shot prompt."""
+
+    def chain_of_thought_prompt(self) -> str | None:
+        return chain_of_thought_prompt(self.task)
+
+
+class MultiShotChainOfThoughtPromptBuilder(MultiShotPromptBuilder):
+    """A prompt builder that includes a chain of thought prompt on top of the multi shot prompt."""
+
+    def chain_of_thought_prompt(self) -> str | None:
+        return chain_of_thought_prompt(self.task)
+
+
 prompt_builder_registry = {
     "simple_prompt_builder": SimplePromptBuilder,
     "multi_shot_prompt_builder": MultiShotPromptBuilder,
     "few_shot_prompt_builder": FewShotPromptBuilder,
     "repairs_prompt_builder": RepairsPromptBuilder,
+    "simple_chain_of_thought_prompt_builder": SimpleChainOfThoughtPromptBuilder,
+    "few_shot_chain_of_thought_prompt_builder": FewShotChainOfThoughtPromptBuilder,
+    "multi_shot_chain_of_thought_prompt_builder": MultiShotChainOfThoughtPromptBuilder,
 }
 
 
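
The three new builders register under the same naming scheme as the existing ones, so callers that resolve builders by id keep working. For example, assuming task is an existing Kiln Task:

builder_class = prompt_builder_registry["simple_chain_of_thought_prompt_builder"]
builder = builder_class(task)
# The UI prompt now includes the thinking instructions section
print(builder.build_prompt_for_ui())
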
@@ -217,5 +277,11 @@ def prompt_builder_from_ui_name(ui_name: str) -> type[BasePromptBuilder]:
             return MultiShotPromptBuilder
         case "repairs":
            return RepairsPromptBuilder
+        case "simple_chain_of_thought":
+            return SimpleChainOfThoughtPromptBuilder
+        case "few_shot_chain_of_thought":
+            return FewShotChainOfThoughtPromptBuilder
+        case "multi_shot_chain_of_thought":
+            return MultiShotChainOfThoughtPromptBuilder
         case _:
             raise ValueError(f"Unknown prompt builder: {ui_name}")
@@ -5,6 +5,7 @@ from unittest.mock import AsyncMock, patch
 import pytest
 from pydantic import ValidationError
 
+from kiln_ai.adapters.base_adapter import RunOutput
 from kiln_ai.adapters.langchain_adapters import (
     LangChainPromptAdapter,
 )
@@ -222,7 +223,9 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
     with patch.object(
         LangChainPromptAdapter, "_run", new_callable=AsyncMock
     ) as mock_run:
-        mock_run.return_value = mocked_output
+        mock_run.return_value = RunOutput(
+            output=mocked_output, intermediate_outputs=None
+        )
 
         adapter = LangChainPromptAdapter(
             repair_task, model_name="llama_3_1_8b", provider="groq"
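
_run now returns a RunOutput rather than the raw output, which is why the mock wraps mocked_output. Judging only from the fields used in this diff, RunOutput bundles the final output with optional intermediate outputs (such as chain-of-thought text), roughly:

run_output = RunOutput(
    output={"count": 1},
    intermediate_outputs={"chain_of_thought": "reasoning..."},
)
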
@@ -1,6 +1,10 @@
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_groq import ChatGroq
 
 from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
+from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder
 from kiln_ai.adapters.test_prompt_adaptors import build_test_task
 
 
@@ -49,3 +53,72 @@ def test_langchain_adapter_info(tmp_path):
     assert model_info.adapter_name == "kiln_langchain_adapter"
     assert model_info.model_name == "llama_3_1_8b"
     assert model_info.model_provider == "ollama"
+
+
+async def test_langchain_adapter_with_cot(tmp_path):
+    task = build_test_task(tmp_path)
+    task.output_json_schema = (
+        '{"type": "object", "properties": {"count": {"type": "integer"}}}'
+    )
+    lca = LangChainPromptAdapter(
+        kiln_task=task,
+        model_name="llama_3_1_8b",
+        provider="ollama",
+        prompt_builder=SimpleChainOfThoughtPromptBuilder(task),
+    )
+
+    # Mock the base model and its invoke method
+    mock_base_model = MagicMock()
+    mock_base_model.invoke.return_value = AIMessage(
+        content="Chain of thought reasoning..."
+    )
+
+    # Create a separate mock for self.model()
+    mock_model_instance = MagicMock()
+    mock_model_instance.invoke.return_value = {"parsed": {"count": 1}}
+
+    # Mock the langchain_model_from function to return the base model
+    mock_model_from = AsyncMock(return_value=mock_base_model)
+
+    # Patch both the langchain_model_from function and self.model()
+    with (
+        patch(
+            "kiln_ai.adapters.langchain_adapters.langchain_model_from", mock_model_from
+        ),
+        patch.object(LangChainPromptAdapter, "model", return_value=mock_model_instance),
+    ):
+        response = await lca._run("test input")
+
+    # First 3 messages are the same for both calls
+    for invoke_args in [
+        mock_base_model.invoke.call_args[0][0],
+        mock_model_instance.invoke.call_args[0][0],
+    ]:
+        assert isinstance(
+            invoke_args[0], SystemMessage
+        )  # First message should be system prompt
+        assert (
+            "You are an assistant which performs math tasks provided in plain text."
+            in invoke_args[0].content
+        )
+        assert isinstance(invoke_args[1], HumanMessage)
+        assert "test input" in invoke_args[1].content
+        assert isinstance(invoke_args[2], SystemMessage)
+        assert "step by step" in invoke_args[2].content
+
+    # the COT should only have 3 messages
+    assert len(mock_base_model.invoke.call_args[0][0]) == 3
+    assert len(mock_model_instance.invoke.call_args[0][0]) == 5
+
+    # the final response should have the COT content and the final instructions
+    invoke_args = mock_model_instance.invoke.call_args[0][0]
+    assert isinstance(invoke_args[3], AIMessage)
+    assert "Chain of thought reasoning..." in invoke_args[3].content
+    assert isinstance(invoke_args[4], SystemMessage)
+    assert "Considering the above, return a final result." in invoke_args[4].content
+
+    assert (
+        response.intermediate_outputs["chain_of_thought"]
+        == "Chain of thought reasoning..."
+    )
+    assert response.output == {"count": 1}
@@ -4,9 +4,11 @@ from unittest.mock import patch
 import pytest
 
 from kiln_ai.adapters.ml_model_list import (
+    ModelName,
     ModelProviderName,
     OllamaConnection,
     check_provider_warnings,
+    get_model_and_provider,
     ollama_model_supported,
     parse_ollama_tags,
     provider_name_from_id,
@@ -123,3 +125,57 @@ def test_ollama_model_supported():
     assert ollama_model_supported(conn, "llama3.1:latest")
     assert ollama_model_supported(conn, "llama3.1")
     assert not ollama_model_supported(conn, "unknown_model")
+
+
+def test_get_model_and_provider_valid():
+    # Test with a known valid model and provider combination
+    model, provider = get_model_and_provider(
+        ModelName.phi_3_5, ModelProviderName.ollama
+    )
+
+    assert model is not None
+    assert provider is not None
+    assert model.name == ModelName.phi_3_5
+    assert provider.name == ModelProviderName.ollama
+    assert provider.provider_options["model"] == "phi3.5"
+
+
+def test_get_model_and_provider_invalid_model():
+    # Test with an invalid model name
+    model, provider = get_model_and_provider(
+        "nonexistent_model", ModelProviderName.ollama
+    )
+
+    assert model is None
+    assert provider is None
+
+
+def test_get_model_and_provider_invalid_provider():
+    # Test with a valid model but invalid provider
+    model, provider = get_model_and_provider(ModelName.phi_3_5, "nonexistent_provider")
+
+    assert model is None
+    assert provider is None
+
+
+def test_get_model_and_provider_valid_model_wrong_provider():
+    # Test with a valid model but a provider that doesn't support it
+    model, provider = get_model_and_provider(
+        ModelName.phi_3_5, ModelProviderName.amazon_bedrock
+    )
+
+    assert model is None
+    assert provider is None
+
+
+def test_get_model_and_provider_multiple_providers():
+    # Test with a model that has multiple providers
+    model, provider = get_model_and_provider(
+        ModelName.llama_3_1_70b, ModelProviderName.groq
+    )
+
+    assert model is not None
+    assert provider is not None
+    assert model.name == ModelName.llama_3_1_70b
+    assert provider.name == ModelProviderName.groq
+    assert provider.provider_options["model"] == "llama-3.1-70b-versatile"
@@ -7,6 +7,18 @@ from langchain_core.language_models.fake_chat_models import FakeListChatModel
 import kiln_ai.datamodel as datamodel
 from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
 from kiln_ai.adapters.ml_model_list import built_in_models, ollama_online
+from kiln_ai.adapters.prompt_builders import (
+    BasePromptBuilder,
+    SimpleChainOfThoughtPromptBuilder,
+)
+
+
+def get_all_models_and_providers():
+    model_provider_pairs = []
+    for model in built_in_models:
+        for provider in model.providers:
+            model_provider_pairs.append((model.name, provider.name))
+    return model_provider_pairs
 
 
 @pytest.mark.paid
@@ -28,6 +40,9 @@ async def test_groq(tmp_path):
         "llama_3_2_3b",
         "llama_3_2_11b",
         "llama_3_2_90b",
+        "claude_3_5_haiku",
+        "claude_3_5_sonnet",
+        "phi_3_5",
     ],
 )
 @pytest.mark.paid
@@ -117,15 +132,19 @@ async def test_mock_returning_run(tmp_path):
 
 @pytest.mark.paid
 @pytest.mark.ollama
-async def test_all_built_in_models(tmp_path):
+@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+async def test_all_models_providers_plaintext(tmp_path, model_name, provider_name):
     task = build_test_task(tmp_path)
-    for model in built_in_models:
-        for provider in model.providers:
-            try:
-                print(f"Running {model.name} {provider.name}")
-                await run_simple_task(task, model.name, provider.name)
-            except Exception as e:
-                raise RuntimeError(f"Error running {model.name} {provider}") from e
+    await run_simple_task(task, model_name, provider_name)
+
+
+@pytest.mark.paid
+@pytest.mark.ollama
+@pytest.mark.parametrize("model_name,provider_name", get_all_models_and_providers())
+async def test_cot_prompt_builder(tmp_path, model_name, provider_name):
+    task = build_test_task(tmp_path)
+    pb = SimpleChainOfThoughtPromptBuilder(task)
+    await run_simple_task(task, model_name, provider_name, pb)
 
 
 def build_test_task(tmp_path: Path):
@@ -157,13 +176,25 @@ def build_test_task(tmp_path: Path):
     return task
 
 
-async def run_simple_test(tmp_path: Path, model_name: str, provider: str | None = None):
+async def run_simple_test(
+    tmp_path: Path,
+    model_name: str,
+    provider: str | None = None,
+    prompt_builder: BasePromptBuilder | None = None,
+):
     task = build_test_task(tmp_path)
-    return await run_simple_task(task, model_name, provider)
+    return await run_simple_task(task, model_name, provider, prompt_builder)
 
 
-async def run_simple_task(task: datamodel.Task, model_name: str, provider: str):
-    adapter = LangChainPromptAdapter(task, model_name=model_name, provider=provider)
+async def run_simple_task(
+    task: datamodel.Task,
+    model_name: str,
+    provider: str,
+    prompt_builder: BasePromptBuilder | None = None,
+) -> datamodel.TaskRun:
+    adapter = LangChainPromptAdapter(
+        task, model_name=model_name, provider=provider, prompt_builder=prompt_builder
+    )
 
     run = await adapter.invoke(
         "You should answer the following question: four plus six times 10"
@@ -174,9 +205,14 @@ async def run_simple_task(task: datamodel.Task, model_name: str, provider: str):
         run.input == "You should answer the following question: four plus six times 10"
     )
     assert "64" in run.output.output
-    assert run.output.source.properties == {
-        "adapter_name": "kiln_langchain_adapter",
-        "model_name": model_name,
-        "model_provider": provider,
-        "prompt_builder_name": "simple_prompt_builder",
-    }
+    source_props = run.output.source.properties
+    assert source_props["adapter_name"] == "kiln_langchain_adapter"
+    assert source_props["model_name"] == model_name
+    assert source_props["model_provider"] == provider
+    expected_prompt_builder_name = (
+        prompt_builder.__class__.prompt_builder_name()
+        if prompt_builder
+        else "simple_prompt_builder"
+    )
+    assert source_props["prompt_builder_name"] == expected_prompt_builder_name
+    return run