kiln-ai 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.

Files changed (42)
  1. kiln_ai/adapters/__init__.py +11 -1
  2. kiln_ai/adapters/adapter_registry.py +19 -0
  3. kiln_ai/adapters/data_gen/__init__.py +11 -0
  4. kiln_ai/adapters/data_gen/data_gen_task.py +69 -1
  5. kiln_ai/adapters/data_gen/test_data_gen_task.py +30 -21
  6. kiln_ai/adapters/fine_tune/__init__.py +14 -0
  7. kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
  8. kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
  9. kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
  10. kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
  11. kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
  12. kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
  13. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
  14. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +455 -0
  15. kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
  16. kiln_ai/adapters/langchain_adapters.py +103 -13
  17. kiln_ai/adapters/ml_model_list.py +218 -304
  18. kiln_ai/adapters/ollama_tools.py +114 -0
  19. kiln_ai/adapters/provider_tools.py +295 -0
  20. kiln_ai/adapters/repair/test_repair_task.py +6 -11
  21. kiln_ai/adapters/test_langchain_adapter.py +46 -18
  22. kiln_ai/adapters/test_ollama_tools.py +42 -0
  23. kiln_ai/adapters/test_prompt_adaptors.py +7 -5
  24. kiln_ai/adapters/test_provider_tools.py +312 -0
  25. kiln_ai/adapters/test_structured_output.py +22 -43
  26. kiln_ai/datamodel/__init__.py +235 -22
  27. kiln_ai/datamodel/basemodel.py +30 -0
  28. kiln_ai/datamodel/registry.py +31 -0
  29. kiln_ai/datamodel/test_basemodel.py +29 -1
  30. kiln_ai/datamodel/test_dataset_split.py +234 -0
  31. kiln_ai/datamodel/test_example_models.py +12 -0
  32. kiln_ai/datamodel/test_models.py +91 -1
  33. kiln_ai/datamodel/test_registry.py +96 -0
  34. kiln_ai/utils/config.py +9 -0
  35. kiln_ai/utils/name_generator.py +125 -0
  36. kiln_ai/utils/test_name_geneator.py +47 -0
  37. {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/METADATA +4 -2
  38. kiln_ai-0.7.0.dist-info/RECORD +56 -0
  39. kiln_ai/adapters/test_ml_model_list.py +0 -181
  40. kiln_ai-0.6.0.dist-info/RECORD +0 -36
  41. {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/WHEEL +0 -0
  42. {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
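The headline addition in 0.7.0 is the new kiln_ai/adapters/fine_tune package (base adapter, dataset formatter, Fireworks and OpenAI fine-tune adapters, plus tests). As a rough sketch of how the new adapter is wired together, based on the test fixtures further down in this diff (the field values are illustrative placeholders, and status() is an async call):

from kiln_ai.adapters.fine_tune.openai_finetune import OpenAIFinetune
from kiln_ai.datamodel import Finetune as FinetuneModel

# Illustrative only: mirrors the openai_finetune fixture in
# test_openai_finetune.py below; IDs and the path are placeholders.
finetune = OpenAIFinetune(
    datamodel=FinetuneModel(
        name="my-finetune",
        provider="openai",
        provider_id="openai-123",        # provider-side job id
        base_model_id="gpt-4o",          # base model to tune
        train_split_name="train",
        dataset_split_id="dataset-123",
        system_message="Test system message",
        path="my-finetune.kiln",         # where the datamodel is persisted
    ),
)

# Poll the provider for job state (async):
# status = await finetune.status()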
kiln_ai/adapters/fine_tune/test_openai_finetune.py
@@ -0,0 +1,503 @@
+ import time
+ from pathlib import Path
+ from unittest import mock
+ from unittest.mock import MagicMock, patch
+
+ import openai
+ import pytest
+ from openai.types.fine_tuning import FineTuningJob
+
+ from kiln_ai.adapters.fine_tune.base_finetune import FineTuneStatusType
+ from kiln_ai.adapters.fine_tune.dataset_formatter import DatasetFormat, DatasetFormatter
+ from kiln_ai.adapters.fine_tune.openai_finetune import OpenAIFinetune
+ from kiln_ai.datamodel import DatasetSplit, Task, Train80Test20SplitDefinition
+ from kiln_ai.datamodel import Finetune as FinetuneModel
+ from kiln_ai.utils.config import Config
+
+
+ @pytest.fixture
+ def openai_finetune(tmp_path):
+     tmp_file = tmp_path / "test-finetune.kiln"
+     finetune = OpenAIFinetune(
+         datamodel=FinetuneModel(
+             name="test-finetune",
+             provider="openai",
+             provider_id="openai-123",
+             base_model_id="gpt-4o",
+             train_split_name="train",
+             dataset_split_id="dataset-123",
+             system_message="Test system message",
+             fine_tune_model_id="ft-123",
+             path=tmp_file,
+         ),
+     )
+     return finetune
+
+
+ @pytest.fixture
+ def mock_response():
+     response = MagicMock(spec=FineTuningJob)
+     response.error = None
+     response.status = "succeeded"
+     response.finished_at = time.time()
+     response.estimated_finish = None
+     response.fine_tuned_model = "ft-123"
+     response.model = "gpt-4o"
+     return response
+
+
+ @pytest.fixture
+ def mock_dataset():
+     return DatasetSplit(
+         id="test-dataset-123",
+         name="Test Dataset",
+         splits=Train80Test20SplitDefinition,
+         split_contents={"train": [], "test": []},
+     )
+
+
+ @pytest.fixture
+ def mock_task():
+     return Task(
+         id="test-task-123",
+         name="Test Task",
+         output_json_schema=None,  # Can be modified in specific tests
+         instruction="Test instruction",
+     )
+
+
+ async def test_setup(openai_finetune):
+     if not Config.shared().open_ai_api_key:
+         pytest.skip("OpenAI API key not set")
+     openai_finetune.provider_id = "openai-123"
+     openai_finetune.provider = "openai"
+
+     # Real API call, with fake ID
+     status = await openai_finetune.status()
+     # fake id fails
+     assert status.status == FineTuneStatusType.unknown
+     assert "Job with this ID not found. It may have been deleted." == status.message
+
+
+ @pytest.mark.parametrize(
+     "exception,expected_status,expected_message",
+     [
+         (
+             openai.APIConnectionError(request=MagicMock()),
+             FineTuneStatusType.unknown,
+             "Server connection error",
+         ),
+         (
+             openai.RateLimitError(
+                 message="Rate limit exceeded", body={}, response=MagicMock()
+             ),
+             FineTuneStatusType.unknown,
+             "Rate limit exceeded",
+         ),
+         (
+             openai.APIStatusError(
+                 "Not found",
+                 response=MagicMock(status_code=404),
+                 body={},
+             ),
+             FineTuneStatusType.unknown,
+             "Job with this ID not found",
+         ),
+         (
+             openai.APIStatusError(
+                 "Server error",
+                 response=MagicMock(status_code=500),
+                 body={},
+             ),
+             FineTuneStatusType.unknown,
+             "Unknown error",
+         ),
+     ],
+ )
+ async def test_status_api_errors(
+     openai_finetune, exception, expected_status, expected_message
+ ):
+     with patch(
+         "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.fine_tuning.jobs.retrieve",
+         side_effect=exception,
+     ):
+         status = await openai_finetune.status()
+         assert status.status == expected_status
+         assert expected_message in status.message
+
+
+ @pytest.mark.parametrize(
+     "job_status,expected_status,message_contains",
+     [
+         ("failed", FineTuneStatusType.failed, "Job failed"),
+         ("cancelled", FineTuneStatusType.failed, "Job cancelled"),
+         ("succeeded", FineTuneStatusType.completed, "Training job completed"),
+         ("running", FineTuneStatusType.running, "Fine tune job is running"),
+         ("queued", FineTuneStatusType.running, "Fine tune job is running"),
+         (
+             "validating_files",
+             FineTuneStatusType.running,
+             "Fine tune job is running",
+         ),
+         ("unknown_status", FineTuneStatusType.unknown, "Unknown status"),
+     ],
+ )
+ async def test_status_job_states(
+     openai_finetune,
+     mock_response,
+     job_status,
+     expected_status,
+     message_contains,
+ ):
+     mock_response.status = job_status
+
+     with patch(
+         "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.fine_tuning.jobs.retrieve",
+         return_value=mock_response,
+     ):
+         status = await openai_finetune.status()
+         assert status.status == expected_status
+         assert message_contains in status.message
+
+
+ async def test_status_with_error_response(openai_finetune, mock_response):
+     mock_response.error = MagicMock()
+     mock_response.error.message = "Something went wrong"
+
+     with patch(
+         "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.fine_tuning.jobs.retrieve",
+         return_value=mock_response,
+     ):
+         status = await openai_finetune.status()
+         assert status.status == FineTuneStatusType.failed
+         assert status.message.startswith("Something went wrong [Code:")
+
+
+ async def test_status_with_estimated_finish_time(openai_finetune, mock_response):
+     current_time = time.time()
+     mock_response.status = "running"
+     mock_response.estimated_finish = current_time + 300  # 5 minutes from now
+
+     with patch(
+         "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.fine_tuning.jobs.retrieve",
+         return_value=mock_response,
+     ):
+         status = await openai_finetune.status()
+         assert status.status == FineTuneStatusType.running
+         assert (
+             "Estimated finish time: 299 seconds" in status.message
+         )  # non zero time passes
+
+
+ async def test_status_empty_response(openai_finetune):
+     with patch(
+         "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.fine_tuning.jobs.retrieve",
+         return_value=mock_response,
+     ):
+         status = await openai_finetune.status()
+         assert status.status == FineTuneStatusType.unknown
+         assert "Invalid response from OpenAI" in status.message
+
+
+ async def test_generate_and_upload_jsonl_success(
+     openai_finetune, mock_dataset, mock_task
+ ):
+     mock_path = Path("mock_path.jsonl")
+     mock_file_id = "file-123"
+
+     # Mock the formatter
+     mock_formatter = MagicMock(spec=DatasetFormatter)
+     mock_formatter.dump_to_file.return_value = mock_path
+
+     # Mock the file response
+     mock_file_response = MagicMock()
+     mock_file_response.id = mock_file_id
+
+     with (
+         patch(
+             "kiln_ai.adapters.fine_tune.openai_finetune.DatasetFormatter",
+             return_value=mock_formatter,
+         ) as mock_formatter_class,
+         patch(
+             "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.files.create",
+             return_value=mock_file_response,
+         ) as mock_create,
+         patch("builtins.open") as mock_open,
+     ):
+         result = await openai_finetune.generate_and_upload_jsonl(
+             mock_dataset, "train", mock_task
+         )
+
+         # Verify formatter was created with correct parameters
+         mock_formatter_class.assert_called_once_with(
+             mock_dataset, openai_finetune.datamodel.system_message
+         )
+
+         # Verify correct format was used
+         mock_formatter.dump_to_file.assert_called_once_with(
+             "train", DatasetFormat.OPENAI_CHAT_JSONL
+         )
+
+         # Verify file was opened and uploaded
+         mock_open.assert_called_once_with(mock_path, "rb")
+         mock_create.assert_called_once()
+
+         assert result == mock_file_id
+
+
+ async def test_generate_and_upload_jsonl_toolcall_success(
+     openai_finetune, mock_dataset, mock_task
+ ):
+     mock_path = Path("mock_path.jsonl")
+     mock_file_id = "file-123"
+     mock_task.output_json_schema = '{"type": "object", "properties": {"key": {"type": "string"}}}'  # Add JSON schema
+
+     # Mock the formatter
+     mock_formatter = MagicMock(spec=DatasetFormatter)
+     mock_formatter.dump_to_file.return_value = mock_path
+
+     # Mock the file response
+     mock_file_response = MagicMock()
+     mock_file_response.id = mock_file_id
+
+     with (
+         patch(
+             "kiln_ai.adapters.fine_tune.openai_finetune.DatasetFormatter",
+             return_value=mock_formatter,
+         ) as mock_formatter_class,
+         patch(
+             "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.files.create",
+             return_value=mock_file_response,
+         ) as mock_create,
+         patch("builtins.open") as mock_open,
+     ):
+         result = await openai_finetune.generate_and_upload_jsonl(
+             mock_dataset, "train", mock_task
+         )
+
+         # Verify formatter was created with correct parameters
+         mock_formatter_class.assert_called_once_with(
+             mock_dataset, openai_finetune.datamodel.system_message
+         )
+
+         # Verify correct format was used
+         mock_formatter.dump_to_file.assert_called_once_with(
+             "train", DatasetFormat.OPENAI_CHAT_TOOLCALL_JSONL
+         )
+
+         # Verify file was opened and uploaded
+         mock_open.assert_called_once_with(mock_path, "rb")
+         mock_create.assert_called_once()
+
+         assert result == mock_file_id
+
+
+ async def test_generate_and_upload_jsonl_upload_failure(
+     openai_finetune, mock_dataset, mock_task
+ ):
+     mock_path = Path("mock_path.jsonl")
+
+     mock_formatter = MagicMock(spec=DatasetFormatter)
+     mock_formatter.dump_to_file.return_value = mock_path
+
+     # Mock response with no ID
+     mock_file_response = MagicMock()
+     mock_file_response.id = None
+
+     with (
+         patch(
+             "kiln_ai.adapters.fine_tune.openai_finetune.DatasetFormatter",
+             return_value=mock_formatter,
+         ),
+         patch(
+             "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.files.create",
+             return_value=mock_file_response,
+         ),
+         patch("builtins.open"),
+     ):
+         with pytest.raises(ValueError, match="Failed to upload file to OpenAI"):
+             await openai_finetune.generate_and_upload_jsonl(
+                 mock_dataset, "train", mock_task
+             )
+
+
+ async def test_generate_and_upload_jsonl_api_error(
+     openai_finetune, mock_dataset, mock_task
+ ):
+     mock_path = Path("mock_path.jsonl")
+
+     mock_formatter = MagicMock(spec=DatasetFormatter)
+     mock_formatter.dump_to_file.return_value = mock_path
+
+     with (
+         patch(
+             "kiln_ai.adapters.fine_tune.openai_finetune.DatasetFormatter",
+             return_value=mock_formatter,
+         ),
+         patch(
+             "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.files.create",
+             side_effect=openai.APIError(
+                 message="API error", request=MagicMock(), body={}
+             ),
+         ),
+         patch("builtins.open"),
+     ):
+         with pytest.raises(openai.APIError):
+             await openai_finetune.generate_and_upload_jsonl(
+                 mock_dataset, "train", mock_task
+             )
+
+
+ async def test_start_success(openai_finetune, mock_dataset, mock_task):
+     openai_finetune.datamodel.parent = mock_task
+
+     # Mock parameters
+     openai_finetune.datamodel.parameters = {
+         "n_epochs": 3,
+         "learning_rate_multiplier": 0.1,
+         "batch_size": 4,
+         "ignored_param": "value",
+     }
+
+     # Mock the fine-tuning response
+     mock_ft_response = MagicMock()
+     mock_ft_response.id = "ft-123"
+     mock_ft_response.fine_tuned_model = None
+     mock_ft_response.model = "gpt-4o-mini-2024-07-18"
+
+     with (
+         patch.object(
+             openai_finetune,
+             "generate_and_upload_jsonl",
+             side_effect=["train-file-123", "val-file-123"],
+         ) as mock_upload,
+         patch(
+             "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.fine_tuning.jobs.create",
+             return_value=mock_ft_response,
+         ) as mock_create,
+     ):
+         await openai_finetune._start(mock_dataset)
+
+         # Verify file uploads
+         assert mock_upload.call_count == 1  # Only training file
+         mock_upload.assert_called_with(
+             mock_dataset, openai_finetune.datamodel.train_split_name, mock_task
+         )
+
+         # Verify fine-tune creation
+         mock_create.assert_called_once_with(
+             training_file="train-file-123",
+             model="gpt-4o",
+             validation_file=None,
+             seed=None,
+             hyperparameters={
+                 "n_epochs": 3,
+                 "learning_rate_multiplier": 0.1,
+                 "batch_size": 4,
+             },
+             suffix=f"kiln_ai.{openai_finetune.datamodel.id}",
+         )
+
+         # Verify model updates
+         assert openai_finetune.datamodel.provider_id == "ft-123"
+         assert openai_finetune.datamodel.base_model_id == "gpt-4o-mini-2024-07-18"
+
+
+ async def test_start_with_validation(openai_finetune, mock_dataset, mock_task):
+     openai_finetune.datamodel.parent = mock_task
+     openai_finetune.datamodel.validation_split_name = "validation"
+
+     mock_ft_response = MagicMock()
+     mock_ft_response.id = "ft-123"
+     mock_ft_response.fine_tuned_model = None
+     mock_ft_response.model = "gpt-4o-mini-2024-07-18"
+
+     with (
+         patch.object(
+             openai_finetune,
+             "generate_and_upload_jsonl",
+             side_effect=["train-file-123", "val-file-123"],
+         ) as mock_upload,
+         patch(
+             "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.fine_tuning.jobs.create",
+             return_value=mock_ft_response,
+         ) as mock_create,
+     ):
+         await openai_finetune._start(mock_dataset)
+
+         # Verify both files were uploaded
+         assert mock_upload.call_count == 2
+         mock_upload.assert_has_calls(
+             [
+                 mock.call(
+                     mock_dataset, openai_finetune.datamodel.train_split_name, mock_task
+                 ),
+                 mock.call(mock_dataset, "validation", mock_task),
+             ]
+         )
+
+         # Verify validation file was included
+         mock_create.assert_called_once()
+         assert mock_create.call_args[1]["validation_file"] == "val-file-123"
+
+
+ async def test_start_no_task(openai_finetune, mock_dataset):
+     openai_finetune.datamodel.parent = None
+     openai_finetune.datamodel.path = None
+
+     with pytest.raises(ValueError, match="Task is required to start a fine-tune"):
+         await openai_finetune._start(mock_dataset)
+
+
+ async def test_status_updates_model_ids(openai_finetune, mock_response):
+     # Set up initial model IDs
+     openai_finetune.datamodel.fine_tune_model_id = "old-ft-model"
+     openai_finetune.datamodel.base_model_id = "old-base-model"
+
+     # Configure mock response with different model IDs
+     mock_response.fine_tuned_model = "new-ft-model"
+     mock_response.model = "new-base-model"
+     mock_response.status = "succeeded"
+
+     with (
+         patch(
+             "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.fine_tuning.jobs.retrieve",
+             return_value=mock_response,
+         ),
+     ):
+         status = await openai_finetune.status()
+
+         # Verify model IDs were updated
+         assert openai_finetune.datamodel.fine_tune_model_id == "new-ft-model"
+         assert openai_finetune.datamodel.base_model_id == "new-base-model"
+
+         # Verify save was called
+         # This isn't properly mocked, so not checking
+         # assert openai_finetune.datamodel.save.called
+
+         # Verify status is still returned correctly
+         assert status.status == FineTuneStatusType.completed
+         assert status.message == "Training job completed"
+
+
+ async def test_status_updates_latest_status(openai_finetune, mock_response):
+     # Set initial status
+     openai_finetune.datamodel.latest_status = FineTuneStatusType.running
+     assert openai_finetune.datamodel.latest_status == FineTuneStatusType.running
+     mock_response.status = "succeeded"
+
+     with (
+         patch(
+             "kiln_ai.adapters.fine_tune.openai_finetune.oai_client.fine_tuning.jobs.retrieve",
+             return_value=mock_response,
+         ),
+     ):
+         status = await openai_finetune.status()
+
+         # Verify status was updated in datamodel
+         assert openai_finetune.datamodel.latest_status == FineTuneStatusType.completed
+         assert status.status == FineTuneStatusType.completed
+         assert status.message == "Training job completed"
+
+         # Verify file was saved
+         assert openai_finetune.datamodel.path.exists()
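Taken together, the parametrized cases in test_status_job_states pin down how OpenAI job states are expected to map onto Kiln's FineTuneStatusType. A minimal sketch of that mapping, for orientation only (the helper name below is hypothetical; the real logic lives inside OpenAIFinetune.status()):

from kiln_ai.adapters.fine_tune.base_finetune import FineTuneStatusType

# Hypothetical helper, written only to summarize the mapping the tests assert.
def map_openai_job_status(job_status: str) -> FineTuneStatusType:
    if job_status in ("failed", "cancelled"):
        return FineTuneStatusType.failed
    if job_status == "succeeded":
        return FineTuneStatusType.completed
    if job_status in ("running", "queued", "validating_files"):
        return FineTuneStatusType.running
    # Anything else is reported as unknown
    return FineTuneStatusType.unknown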
kiln_ai/adapters/langchain_adapters.py
@@ -1,21 +1,35 @@
- from typing import Dict
+ import os
+ from os import getenv
+ from typing import Any, Dict

+ from langchain_aws import ChatBedrockConverse
  from langchain_core.language_models import LanguageModelInput
  from langchain_core.language_models.chat_models import BaseChatModel
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
  from langchain_core.messages.base import BaseMessage
  from langchain_core.runnables import Runnable
+ from langchain_fireworks import ChatFireworks
+ from langchain_groq import ChatGroq
+ from langchain_ollama import ChatOllama
+ from langchain_openai import ChatOpenAI
  from pydantic import BaseModel

  import kiln_ai.datamodel as datamodel
+ from kiln_ai.adapters.ollama_tools import (
+     get_ollama_connection,
+     ollama_base_url,
+     ollama_model_installed,
+ )
+ from kiln_ai.utils.config import Config

  from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder, RunOutput
- from .ml_model_list import langchain_model_from
+ from .ml_model_list import KilnModelProvider, ModelProviderName
+ from .provider_tools import kiln_model_provider_from

  LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]


- class LangChainPromptAdapter(BaseAdapter):
+ class LangchainAdapter(BaseAdapter):
      _model: LangChainModelType | None = None

      def __init__(
@@ -51,12 +65,6 @@ class LangChainPromptAdapter(BaseAdapter):
                  "model_name and provider must be provided if custom_model is not provided"
              )

-     def adapter_specific_instructions(self) -> str | None:
-         # TODO: would be better to explicitly use bind_tools:tool_choice="task_response" here
-         if self.has_structured_output():
-             return "Always respond with a tool call. Never respond with a human readable message."
-         return None
-
      async def model(self) -> LangChainModelType:
          # cached model
          if self._model:
@@ -79,8 +87,13 @@ class LangChainPromptAdapter(BaseAdapter):
                  )
              output_schema["title"] = "task_response"
              output_schema["description"] = "A response from the task"
+             with_structured_output_options = await get_structured_output_options(
+                 self.model_name, self.model_provider
+             )
              self._model = self._model.with_structured_output(
-                 output_schema, include_raw=True
+                 output_schema,
+                 include_raw=True,
+                 **with_structured_output_options,
              )
          return self._model

@@ -108,17 +121,16 @@ class LangChainPromptAdapter(BaseAdapter):
              )

              cot_messages = [*messages]
-             cot_response = base_model.invoke(cot_messages)
+             cot_response = await base_model.ainvoke(cot_messages)
              intermediate_outputs["chain_of_thought"] = cot_response.content
              messages.append(AIMessage(content=cot_response.content))
              messages.append(
                  SystemMessage(content="Considering the above, return a final result.")
              )
          elif cot_prompt:
-             # for plaintext output, we just add COT instructions. We still only make one call.
              messages.append(SystemMessage(content=cot_prompt))

-         response = chain.invoke(messages)
+         response = await chain.ainvoke(messages)

          if self.has_structured_output():
              if (
@@ -160,3 +172,81 @@ class LangChainPromptAdapter(BaseAdapter):
          ):
              return response["arguments"]
          return response
+
+
+ async def get_structured_output_options(
+     model_name: str, model_provider: str
+ ) -> Dict[str, Any]:
+     finetune_provider = await kiln_model_provider_from(model_name, model_provider)
+     if finetune_provider and finetune_provider.adapter_options.get("langchain"):
+         return finetune_provider.adapter_options["langchain"].get(
+             "with_structured_output_options", {}
+         )
+     return {}
+
+
+ async def langchain_model_from(
+     name: str, provider_name: str | None = None
+ ) -> BaseChatModel:
+     provider = await kiln_model_provider_from(name, provider_name)
+     return await langchain_model_from_provider(provider, name)
+
+
+ async def langchain_model_from_provider(
+     provider: KilnModelProvider, model_name: str
+ ) -> BaseChatModel:
+     if provider.name == ModelProviderName.openai:
+         api_key = Config.shared().open_ai_api_key
+         return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
+     elif provider.name == ModelProviderName.groq:
+         api_key = Config.shared().groq_api_key
+         if api_key is None:
+             raise ValueError(
+                 "Attempted to use Groq without an API key set. "
+                 "Get your API key from https://console.groq.com/keys"
+             )
+         return ChatGroq(**provider.provider_options, groq_api_key=api_key)  # type: ignore[arg-type]
+     elif provider.name == ModelProviderName.amazon_bedrock:
+         api_key = Config.shared().bedrock_access_key
+         secret_key = Config.shared().bedrock_secret_key
+         # langchain doesn't allow passing these, so ugly hack to set env vars
+         os.environ["AWS_ACCESS_KEY_ID"] = api_key
+         os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
+         return ChatBedrockConverse(
+             **provider.provider_options,
+         )
+     elif provider.name == ModelProviderName.fireworks_ai:
+         api_key = Config.shared().fireworks_api_key
+         return ChatFireworks(**provider.provider_options, api_key=api_key)
+     elif provider.name == ModelProviderName.ollama:
+         # Ollama model naming is pretty flexible. We try a few versions of the model name
+         potential_model_names = []
+         if "model" in provider.provider_options:
+             potential_model_names.append(provider.provider_options["model"])
+         if "model_aliases" in provider.provider_options:
+             potential_model_names.extend(provider.provider_options["model_aliases"])
+
+         # Get the list of models Ollama supports
+         ollama_connection = await get_ollama_connection()
+         if ollama_connection is None:
+             raise ValueError("Failed to connect to Ollama. Ensure Ollama is running.")
+
+         for model_name in potential_model_names:
+             if ollama_model_installed(ollama_connection, model_name):
+                 return ChatOllama(model=model_name, base_url=ollama_base_url())
+
+         raise ValueError(f"Model {model_name} not installed on Ollama")
+     elif provider.name == ModelProviderName.openrouter:
+         api_key = Config.shared().open_router_api_key
+         base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
+         return ChatOpenAI(
+             **provider.provider_options,
+             openai_api_key=api_key,  # type: ignore[arg-type]
+             openai_api_base=base_url,  # type: ignore[arg-type]
+             default_headers={
+                 "HTTP-Referer": "https://getkiln.ai/openrouter",
+                 "X-Title": "KilnAI",
+             },
+         )
+     else:
+         raise ValueError(f"Invalid model or provider: {model_name} - {provider.name}")
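For context on the new module-level helpers above: langchain_model_from resolves a Kiln model/provider pair into a configured LangChain chat model, and get_structured_output_options feeds provider-specific with_structured_output options back into LangchainAdapter. A minimal usage sketch (the model and provider names here are illustrative, not taken from kiln's model list):

import asyncio

from kiln_ai.adapters.langchain_adapters import langchain_model_from


async def main():
    # Resolves a KilnModelProvider via kiln_model_provider_from, then builds
    # the matching LangChain chat model (ChatOpenAI, ChatOllama, ChatGroq, ...).
    model = await langchain_model_from("llama_3_1_8b", "ollama")
    reply = await model.ainvoke("Say hello in one word.")
    print(reply.content)


asyncio.run(main())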