kiln-ai 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

kiln_ai/adapters/langchain_adapters.py

@@ -1,16 +1,23 @@
  from typing import Dict

+ from langchain_core.language_models import LanguageModelInput
  from langchain_core.language_models.chat_models import BaseChatModel
  from langchain_core.messages import HumanMessage, SystemMessage
  from langchain_core.messages.base import BaseMessage
+ from langchain_core.runnables import Runnable
+ from pydantic import BaseModel

  import kiln_ai.datamodel as datamodel

  from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder
  from .ml_model_list import langchain_model_from

+ LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]
+

  class LangChainPromptAdapter(BaseAdapter):
+     _model: LangChainModelType | None = None
+
      def __init__(
          self,
          kiln_task: datamodel.Task,
@@ -21,7 +28,7 @@ class LangChainPromptAdapter(BaseAdapter):
      ):
          super().__init__(kiln_task, prompt_builder=prompt_builder)
          if custom_model is not None:
-             self.model = custom_model
+             self._model = custom_model

              # Attempt to infer model provider and name from custom model
              self.model_provider = "custom.langchain:" + custom_model.__class__.__name__
@@ -37,19 +44,32 @@ class LangChainPromptAdapter(BaseAdapter):
              ):
                  self.model_name = "custom.langchain:" + getattr(custom_model, "model")
          elif model_name is not None:
-             self.model = langchain_model_from(model_name, provider)
              self.model_name = model_name
              self.model_provider = provider or "custom.langchain.default_provider"
          else:
              raise ValueError(
                  "model_name and provider must be provided if custom_model is not provided"
              )
+
+     def adapter_specific_instructions(self) -> str | None:
+         # TODO: would be better to explicitly use bind_tools:tool_choice="task_response" here
          if self.has_structured_output():
-             if not hasattr(self.model, "with_structured_output") or not callable(
-                 getattr(self.model, "with_structured_output")
+             return "Always respond with a tool call. Never respond with a human readable message."
+         return None
+
+     async def model(self) -> LangChainModelType:
+         # cached model
+         if self._model:
+             return self._model
+
+         self._model = await langchain_model_from(self.model_name, self.model_provider)
+
+         if self.has_structured_output():
+             if not hasattr(self._model, "with_structured_output") or not callable(
+                 getattr(self._model, "with_structured_output")
              ):
                  raise ValueError(
-                     f"model {self.model} does not support structured output, cannot use output_json_schema"
+                     f"model {self._model} does not support structured output, cannot use output_json_schema"
                  )
              # Langchain expects title/description to be at top level, on top of json schema
              output_schema = self.kiln_task.output_schema()
@@ -59,15 +79,10 @@ class LangChainPromptAdapter(BaseAdapter):
              )
              output_schema["title"] = "task_response"
              output_schema["description"] = "A response from the task"
-             self.model = self.model.with_structured_output(
+             self._model = self._model.with_structured_output(
                  output_schema, include_raw=True
              )
-
-     def adapter_specific_instructions(self) -> str | None:
-         # TODO: would be better to explicitly use bind_tools:tool_choice="task_response" here
-         if self.has_structured_output():
-             return "Always respond with a tool call. Never respond with a human readable message."
-         return None
+         return self._model

      async def _run(self, input: Dict | str) -> Dict | str:
          prompt = self.build_prompt()
@@ -76,7 +91,8 @@ class LangChainPromptAdapter(BaseAdapter):
              SystemMessage(content=prompt),
              HumanMessage(content=user_msg),
          ]
-         response = self.model.invoke(messages)
+         model = await self.model()
+         response = model.invoke(messages)

          if self.has_structured_output():
              if (
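
The langchain_adapters.py changes above replace eager model construction in `__init__` with a lazy, cached async accessor: the first `await adapter.model()` resolves the LangChain model (wrapping it with `with_structured_output` when the task defines an output schema) and caches it on `_model`; `_run` then awaits that accessor before invoking. Below is a minimal sketch of that caching pattern only, using illustrative names rather than the package's own classes:

    import asyncio

    class LazyModelHolder:
        """Illustrative only: cache an expensively built model behind an async accessor."""

        _model: object | None = None

        async def model(self) -> object:
            if self._model is not None:
                return self._model  # later calls return the cached instance
            self._model = await self._build()  # stand-in for langchain_model_from(...)
            return self._model

        async def _build(self) -> object:
            await asyncio.sleep(0)  # placeholder for async provider discovery
            return object()

    async def main() -> None:
        holder = LazyModelHolder()
        assert await holder.model() is await holder.model()  # built once, then cached

    asyncio.run(main())

Deferring construction to an awaitable accessor is what allows the Ollama provider lookup further down in this release to be asynchronous.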

kiln_ai/adapters/ml_model_list.py

@@ -2,9 +2,10 @@ import os
  from dataclasses import dataclass
  from enum import Enum
  from os import getenv
- from typing import Dict, List, NoReturn
+ from typing import Any, Dict, List, NoReturn

  import httpx
+ import requests
  from langchain_aws import ChatBedrockConverse
  from langchain_core.language_models.chat_models import BaseChatModel
  from langchain_groq import ChatGroq
@@ -43,6 +44,8 @@ class ModelFamily(str, Enum):
      phi = "phi"
      mistral = "mistral"
      gemma = "gemma"
+     gemini = "gemini"
+     claude = "claude"


  # Where models have instruct and raw versions, instruct is default and raw is specified
@@ -55,6 +58,9 @@ class ModelName(str, Enum):
      llama_3_1_8b = "llama_3_1_8b"
      llama_3_1_70b = "llama_3_1_70b"
      llama_3_1_405b = "llama_3_1_405b"
+     llama_3_2_3b = "llama_3_2_3b"
+     llama_3_2_11b = "llama_3_2_11b"
+     llama_3_2_90b = "llama_3_2_90b"
      gpt_4o_mini = "gpt_4o_mini"
      gpt_4o = "gpt_4o"
      phi_3_5 = "phi_3_5"
@@ -63,6 +69,11 @@ class ModelName(str, Enum):
      gemma_2_2b = "gemma_2_2b"
      gemma_2_9b = "gemma_2_9b"
      gemma_2_27b = "gemma_2_27b"
+     claude_3_5_sonnet = "claude_3_5_sonnet"
+     gemini_1_5_flash = "gemini_1_5_flash"
+     gemini_1_5_flash_8b = "gemini_1_5_flash_8b"
+     gemini_1_5_pro = "gemini_1_5_pro"
+     nemotron_70b = "nemotron_70b"


  class KilnModelProvider(BaseModel):
@@ -132,6 +143,67 @@ built_in_models: List[KilnModel] = [
              ),
          ],
      ),
+     # Claude 3.5 Sonnet
+     KilnModel(
+         family=ModelFamily.claude,
+         name=ModelName.claude_3_5_sonnet,
+         friendly_name="Claude 3.5 Sonnet",
+         providers=[
+             KilnModelProvider(
+                 name=ModelProviderName.openrouter,
+                 provider_options={"model": "anthropic/claude-3.5-sonnet"},
+             ),
+         ],
+     ),
+     # Gemini 1.5 Pro
+     KilnModel(
+         family=ModelFamily.gemini,
+         name=ModelName.gemini_1_5_pro,
+         friendly_name="Gemini 1.5 Pro",
+         providers=[
+             KilnModelProvider(
+                 name=ModelProviderName.openrouter,
+                 provider_options={"model": "google/gemini-pro-1.5"},
+             ),
+         ],
+     ),
+     # Gemini 1.5 Flash
+     KilnModel(
+         family=ModelFamily.gemini,
+         name=ModelName.gemini_1_5_flash,
+         friendly_name="Gemini 1.5 Flash",
+         providers=[
+             KilnModelProvider(
+                 name=ModelProviderName.openrouter,
+                 provider_options={"model": "google/gemini-flash-1.5"},
+             ),
+         ],
+     ),
+     # Gemini 1.5 Flash 8B
+     KilnModel(
+         family=ModelFamily.gemini,
+         name=ModelName.gemini_1_5_flash_8b,
+         friendly_name="Gemini 1.5 Flash 8B",
+         providers=[
+             KilnModelProvider(
+                 name=ModelProviderName.openrouter,
+                 provider_options={"model": "google/gemini-flash-1.5-8b"},
+             ),
+         ],
+     ),
+     # Nemotron 70B
+     KilnModel(
+         family=ModelFamily.llama,
+         name=ModelName.nemotron_70b,
+         friendly_name="Nemotron 70B",
+         providers=[
+             KilnModelProvider(
+                 name=ModelProviderName.openrouter,
+                 supports_structured_output=False,
+                 provider_options={"model": "nvidia/llama-3.1-nemotron-70b-instruct"},
+             ),
+         ],
+     ),
      # Llama 3.1-8b
      KilnModel(
          family=ModelFamily.llama,
@@ -144,6 +216,7 @@ built_in_models: List[KilnModel] = [
              ),
              KilnModelProvider(
                  name=ModelProviderName.amazon_bedrock,
+                 supports_structured_output=False,
                  provider_options={
                      "model": "meta.llama3-1-8b-instruct-v1:0",
                      "region_name": "us-west-2", # Llama 3.1 only in west-2
@@ -151,10 +224,14 @@ built_in_models: List[KilnModel] = [
              ),
              KilnModelProvider(
                  name=ModelProviderName.ollama,
-                 provider_options={"model": "llama3.1"}, # 8b is default
+                 provider_options={
+                     "model": "llama3.1:8b",
+                     "model_aliases": ["llama3.1"], # 8b is default
+                 },
              ),
              KilnModelProvider(
                  name=ModelProviderName.openrouter,
+                 supports_structured_output=False,
                  provider_options={"model": "meta-llama/llama-3.1-8b-instruct"},
              ),
          ],
@@ -171,7 +248,6 @@ built_in_models: List[KilnModel] = [
              ),
              KilnModelProvider(
                  name=ModelProviderName.amazon_bedrock,
-                 # TODO: this should work but a bug in the bedrock response schema
                  supports_structured_output=False,
                  provider_options={
                      "model": "meta.llama3-1-70b-instruct-v1:0",
@@ -182,11 +258,10 @@ built_in_models: List[KilnModel] = [
                  name=ModelProviderName.openrouter,
                  provider_options={"model": "meta-llama/llama-3.1-70b-instruct"},
              ),
-             # TODO: enable once tests update to check if model is available
-             # KilnModelProvider(
-             # provider=ModelProviders.ollama,
-             # provider_options={"model": "llama3.1:70b"},
-             # ),
+             KilnModelProvider(
+                 name=ModelProviderName.ollama,
+                 provider_options={"model": "llama3.1:70b"},
+             ),
          ],
      ),
      # Llama 3.1 405b
@@ -195,11 +270,6 @@ built_in_models: List[KilnModel] = [
          name=ModelName.llama_3_1_405b,
          friendly_name="Llama 3.1 405B",
          providers=[
-             # TODO: bring back when groq does: https://console.groq.com/docs/models
-             # KilnModelProvider(
-             # name=ModelProviderName.groq,
-             # provider_options={"model": "llama-3.1-405b-instruct-v1:0"},
-             # ),
              KilnModelProvider(
                  name=ModelProviderName.amazon_bedrock,
                  provider_options={
@@ -207,11 +277,10 @@ built_in_models: List[KilnModel] = [
                      "region_name": "us-west-2", # Llama 3.1 only in west-2
                  },
              ),
-             # TODO: enable once tests update to check if model is available
-             # KilnModelProvider(
-             # name=ModelProviderName.ollama,
-             # provider_options={"model": "llama3.1:405b"},
-             # ),
+             KilnModelProvider(
+                 name=ModelProviderName.ollama,
+                 provider_options={"model": "llama3.1:405b"},
+             ),
              KilnModelProvider(
                  name=ModelProviderName.openrouter,
                  provider_options={"model": "meta-llama/llama-3.1-405b-instruct"},
@@ -247,11 +316,49 @@ built_in_models: List[KilnModel] = [
                  name=ModelProviderName.openrouter,
                  provider_options={"model": "mistralai/mistral-large"},
              ),
-             # TODO: enable once tests update to check if model is available
-             # KilnModelProvider(
-             # provider=ModelProviders.ollama,
-             # provider_options={"model": "mistral-large"},
-             # ),
+             KilnModelProvider(
+                 name=ModelProviderName.ollama,
+                 provider_options={"model": "mistral-large"},
+             ),
+         ],
+     ),
+     # Llama 3.2 3B
+     KilnModel(
+         family=ModelFamily.llama,
+         name=ModelName.llama_3_2_3b,
+         friendly_name="Llama 3.2 3B",
+         providers=[
+             KilnModelProvider(
+                 name=ModelProviderName.openrouter,
+                 supports_structured_output=False,
+                 provider_options={"model": "meta-llama/llama-3.2-3b-instruct"},
+             ),
+         ],
+     ),
+     # Llama 3.2 11B
+     KilnModel(
+         family=ModelFamily.llama,
+         name=ModelName.llama_3_2_11b,
+         friendly_name="Llama 3.2 11B",
+         providers=[
+             KilnModelProvider(
+                 name=ModelProviderName.openrouter,
+                 supports_structured_output=False,
+                 provider_options={"model": "meta-llama/llama-3.2-11b-vision-instruct"},
+             ),
+         ],
+     ),
+     # Llama 3.2 90B
+     KilnModel(
+         family=ModelFamily.llama,
+         name=ModelName.llama_3_2_90b,
+         friendly_name="Llama 3.2 90B",
+         providers=[
+             KilnModelProvider(
+                 name=ModelProviderName.openrouter,
+                 supports_structured_output=False,
+                 provider_options={"model": "meta-llama/llama-3.2-90b-vision-instruct"},
+             ),
          ],
      ),
      # Phi 3.5
@@ -263,6 +370,7 @@ built_in_models: List[KilnModel] = [
          providers=[
              KilnModelProvider(
                  name=ModelProviderName.ollama,
+                 supports_structured_output=False,
                  provider_options={"model": "phi3.5"},
              ),
              KilnModelProvider(
@@ -280,6 +388,7 @@ built_in_models: List[KilnModel] = [
          providers=[
              KilnModelProvider(
                  name=ModelProviderName.ollama,
+                 supports_structured_output=False,
                  provider_options={
                      "model": "gemma2:2b",
                  },
@@ -293,13 +402,12 @@ built_in_models: List[KilnModel] = [
          friendly_name="Gemma 2 9B",
          supports_structured_output=False,
          providers=[
-             # TODO: enable once tests update to check if model is available
-             # KilnModelProvider(
-             # name=ModelProviderName.ollama,
-             # provider_options={
-             # "model": "gemma2:9b",
-             # },
-             # ),
+             KilnModelProvider(
+                 name=ModelProviderName.ollama,
+                 provider_options={
+                     "model": "gemma2:9b",
+                 },
+             ),
              KilnModelProvider(
                  name=ModelProviderName.openrouter,
                  provider_options={"model": "google/gemma-2-9b-it"}
@@ -313,13 +421,12 @@ built_in_models: List[KilnModel] = [
          friendly_name="Gemma 2 27B",
          supports_structured_output=False,
          providers=[
-             # TODO: enable once tests update to check if model is available
-             # KilnModelProvider(
-             # name=ModelProviderName.ollama,
-             # provider_options={
-             # "model": "gemma2:27b",
-             # },
-             # ),
+             KilnModelProvider(
+                 name=ModelProviderName.ollama,
+                 provider_options={
+                     "model": "gemma2:27b",
+                 },
+             ),
              KilnModelProvider(
                  name=ModelProviderName.openrouter,
                  provider_options={"model": "google/gemma-2-27b-it"}

@@ -417,7 +524,9 @@ def check_provider_warnings(provider_name: ModelProviderName):
          raise ValueError(warning_check.message)


- def langchain_model_from(name: str, provider_name: str | None = None) -> BaseChatModel:
+ async def langchain_model_from(
+     name: str, provider_name: str | None = None
+ ) -> BaseChatModel:
      """
      Creates a LangChain chat model instance for the specified model and provider.

@@ -476,7 +585,23 @@ def langchain_model_from(name: str, provider_name: str | None = None) -> BaseChatModel:
              **provider.provider_options,
          )
      elif provider.name == ModelProviderName.ollama:
-         return ChatOllama(**provider.provider_options, base_url=ollama_base_url())
+         # Ollama model naming is pretty flexible. We try a few versions of the model name
+         potential_model_names = []
+         if "model" in provider.provider_options:
+             potential_model_names.append(provider.provider_options["model"])
+         if "model_aliases" in provider.provider_options:
+             potential_model_names.extend(provider.provider_options["model_aliases"])
+
+         # Get the list of models Ollama supports
+         ollama_connection = await get_ollama_connection()
+         if ollama_connection is None:
+             raise ValueError("Failed to connect to Ollama. Ensure Ollama is running.")
+
+         for model_name in potential_model_names:
+             if ollama_model_supported(ollama_connection, model_name):
+                 return ChatOllama(model=model_name, base_url=ollama_base_url())
+
+         raise ValueError(f"Model {name} not installed on Ollama")
      elif provider.name == ModelProviderName.openrouter:
          api_key = Config.shared().open_router_api_key
          base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"

@@ -519,3 +644,67 @@ async def ollama_online() -> bool:
      except httpx.RequestError:
          return False
      return True
+
+
+ class OllamaConnection(BaseModel):
+     message: str
+     models: List[str]
+
+
+ # Parse the Ollama /api/tags response
+ def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
+     # Build a list of models we support for Ollama from the built-in model list
+     supported_ollama_models = [
+         provider.provider_options["model"]
+         for model in built_in_models
+         for provider in model.providers
+         if provider.name == ModelProviderName.ollama
+     ]
+     # Append model_aliases to supported_ollama_models
+     supported_ollama_models.extend(
+         [
+             alias
+             for model in built_in_models
+             for provider in model.providers
+             for alias in provider.provider_options.get("model_aliases", [])
+         ]
+     )
+
+     if "models" in tags:
+         models = tags["models"]
+         if isinstance(models, list):
+             model_names = [model["model"] for model in models]
+             print(f"model_names: {model_names}")
+             available_supported_models = [
+                 model
+                 for model in model_names
+                 if model in supported_ollama_models
+                 or model in [f"{m}:latest" for m in supported_ollama_models]
+             ]
+             if available_supported_models:
+                 return OllamaConnection(
+                     message="Ollama connected",
+                     models=available_supported_models,
+                 )
+
+     return OllamaConnection(
+         message="Ollama is running, but no supported models are installed. Install one or more supported model, like 'ollama pull phi3.5'.",
+         models=[],
+     )
+
+
+ async def get_ollama_connection() -> OllamaConnection | None:
+     """
+     Gets the connection status for Ollama.
+     """
+     try:
+         tags = requests.get(ollama_base_url() + "/api/tags", timeout=5).json()
+
+     except Exception:
+         return None
+
+     return parse_ollama_tags(tags)
+
+
+ def ollama_model_supported(conn: OllamaConnection, model_name: str) -> bool:
+     return model_name in conn.models or f"{model_name}:latest" in conn.models
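
The new helpers above let Kiln ask a local Ollama server what is installed: `parse_ollama_tags` filters the `/api/tags` payload down to models in the built-in list, and `ollama_model_supported` also accepts the implicit `:latest` tag. A small hedged sketch with a hand-written payload (the tests in the next file exercise the same behavior against a captured response):

    from kiln_ai.adapters.ml_model_list import ollama_model_supported, parse_ollama_tags

    tags = {"models": [{"model": "phi3.5:latest"}, {"model": "not-a-kiln-model:latest"}]}
    conn = parse_ollama_tags(tags)
    print(conn.models)  # ["phi3.5:latest"]: only supported models are kept
    print(ollama_model_supported(conn, "phi3.5"))  # True, the ":latest" suffix is tolerated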

kiln_ai/adapters/test_ml_model_list.py

@@ -1,10 +1,14 @@
+ import json
  from unittest.mock import patch

  import pytest

  from kiln_ai.adapters.ml_model_list import (
      ModelProviderName,
+     OllamaConnection,
      check_provider_warnings,
+     ollama_model_supported,
+     parse_ollama_tags,
      provider_name_from_id,
      provider_warnings,
  )
@@ -97,3 +101,25 @@ def test_provider_name_from_id_case_sensitivity():
  )
  def test_provider_name_from_id_parametrized(provider_id, expected_name):
      assert provider_name_from_id(provider_id) == expected_name
+
+
+ def test_parse_ollama_tags_no_models():
+     json_response = '{"models":[{"name":"phi3.5:latest","model":"phi3.5:latest","modified_at":"2024-10-02T12:04:35.191519822-04:00","size":2176178843,"digest":"61819fb370a3c1a9be6694869331e5f85f867a079e9271d66cb223acb81d04ba","details":{"parent_model":"","format":"gguf","family":"phi3","families":["phi3"],"parameter_size":"3.8B","quantization_level":"Q4_0"}},{"name":"gemma2:2b","model":"gemma2:2b","modified_at":"2024-09-09T16:46:38.64348929-04:00","size":1629518495,"digest":"8ccf136fdd5298f3ffe2d69862750ea7fb56555fa4d5b18c04e3fa4d82ee09d7","details":{"parent_model":"","format":"gguf","family":"gemma2","families":["gemma2"],"parameter_size":"2.6B","quantization_level":"Q4_0"}},{"name":"llama3.1:latest","model":"llama3.1:latest","modified_at":"2024-09-01T17:19:43.481523695-04:00","size":4661230720,"digest":"f66fc8dc39ea206e03ff6764fcc696b1b4dfb693f0b6ef751731dd4e6269046e","details":{"parent_model":"","format":"gguf","family":"llama","families":["llama"],"parameter_size":"8.0B","quantization_level":"Q4_0"}}]}'
+     tags = json.loads(json_response)
+     print(json.dumps(tags, indent=2))
+     conn = parse_ollama_tags(tags)
+     assert "phi3.5:latest" in conn.models
+     assert "gemma2:2b" in conn.models
+     assert "llama3.1:latest" in conn.models
+
+
+ def test_ollama_model_supported():
+     conn = OllamaConnection(
+         models=["phi3.5:latest", "gemma2:2b", "llama3.1:latest"], message="Connected"
+     )
+     assert ollama_model_supported(conn, "phi3.5:latest")
+     assert ollama_model_supported(conn, "phi3.5")
+     assert ollama_model_supported(conn, "gemma2:2b")
+     assert ollama_model_supported(conn, "llama3.1:latest")
+     assert ollama_model_supported(conn, "llama3.1")
+     assert not ollama_model_supported(conn, "unknown_model")

kiln_ai/adapters/test_prompt_adaptors.py

@@ -16,9 +16,23 @@ async def test_groq(tmp_path):
      await run_simple_test(tmp_path, "llama_3_1_8b", "groq")


+ @pytest.mark.parametrize(
+     "model_name",
+     [
+         "llama_3_1_8b",
+         "llama_3_1_70b",
+         "gemini_1_5_pro",
+         "gemini_1_5_flash",
+         "gemini_1_5_flash_8b",
+         "nemotron_70b",
+         "llama_3_2_3b",
+         "llama_3_2_11b",
+         "llama_3_2_90b",
+     ],
+ )
  @pytest.mark.paid
- async def test_openrouter(tmp_path):
-     await run_simple_test(tmp_path, "llama_3_1_8b", "openrouter")
+ async def test_openrouter(tmp_path, model_name):
+     await run_simple_test(tmp_path, model_name, "openrouter")


  @pytest.mark.ollama

kiln_ai/adapters/test_structured_output.py

@@ -15,19 +15,21 @@ from kiln_ai.adapters.ml_model_list import (
  from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema


+ @pytest.mark.parametrize(
+     "model_name,provider",
+     [
+         ("llama_3_1_8b", "groq"),
+         ("mistral_nemo", "openrouter"),
+         ("llama_3_1_70b", "amazon_bedrock"),
+         ("claude_3_5_sonnet", "openrouter"),
+         ("gemini_1_5_pro", "openrouter"),
+         ("gemini_1_5_flash", "openrouter"),
+         ("gemini_1_5_flash_8b", "openrouter"),
+     ],
+ )
  @pytest.mark.paid
- async def test_structured_output_groq(tmp_path):
-     await run_structured_output_test(tmp_path, "llama_3_1_8b", "groq")
-
-
- @pytest.mark.paid
- async def test_structured_output_openrouter(tmp_path):
-     await run_structured_output_test(tmp_path, "mistral_nemo", "openrouter")
-
-
- @pytest.mark.paid
- async def test_structured_output_bedrock(tmp_path):
-     await run_structured_output_test(tmp_path, "llama_3_1_70b", "amazon_bedrock")
+ async def test_structured_output(tmp_path, model_name, provider):
+     await run_structured_output_test(tmp_path, model_name, provider)


  @pytest.mark.ollama
@@ -39,16 +41,17 @@ async def test_structured_output_ollama_phi(tmp_path):
      await run_structured_output_test(tmp_path, "phi_3_5", "ollama")


- @pytest.mark.ollama
+ @pytest.mark.paid
  async def test_structured_output_gpt_4o_mini(tmp_path):
      await run_structured_output_test(tmp_path, "gpt_4o_mini", "openai")


+ @pytest.mark.parametrize("model_name", ["llama_3_1_8b"])
  @pytest.mark.ollama
- async def test_structured_output_ollama_llama(tmp_path):
+ async def test_structured_output_ollama_llama(tmp_path, model_name):
      if not await ollama_online():
          pytest.skip("Ollama API not running. Expect it running on localhost:11434")
-     await run_structured_output_test(tmp_path, "llama_3_1_8b", "ollama")
+     await run_structured_output_test(tmp_path, model_name, "ollama")


  class MockAdapter(BaseAdapter):
@@ -105,6 +108,7 @@ async def test_mock_unstructred_response(tmp_path):
  @pytest.mark.paid
  @pytest.mark.ollama
  async def test_all_built_in_models_structured_output(tmp_path):
+     errors = []
      for model in built_in_models:
          if not model.supports_structured_output:
              print(
@@ -121,7 +125,10 @@ async def test_all_built_in_models_structured_output(tmp_path):
                  print(f"Running {model.name} {provider.name}")
                  await run_structured_output_test(tmp_path, model.name, provider.name)
              except Exception as e:
-                 raise RuntimeError(f"Error running {model.name} {provider}") from e
+                 print(f"Error running {model.name} {provider.name}")
+                 errors.append(f"{model.name} {provider.name}: {e}")
+     if len(errors) > 0:
+         raise RuntimeError(f"Errors: {errors}")


  def build_structured_output_test_task(tmp_path: Path):

kiln_ai-0.5.4.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: kiln-ai
- Version: 0.5.2
+ Version: 0.5.4
  Summary: Kiln AI
  Project-URL: Homepage, https://getkiln.ai
  Project-URL: Repository, https://github.com/Kiln-AI/kiln
@@ -8,6 +8,12 @@ Project-URL: Documentation, https://kiln-ai.github.io/Kiln/kiln_core_docs/kiln_a
  Project-URL: Issues, https://github.com/Kiln-AI/kiln/issues
  Author-email: "Steve Cosman, Chesterfield Laboratories Inc" <scosman@users.noreply.github.com>
  License-File: LICENSE.txt
+ Classifier: Intended Audience :: Developers
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
  Requires-Python: >=3.10
  Requires-Dist: coverage>=7.6.4
  Requires-Dist: jsonschema>=4.23.0

kiln_ai-0.5.4.dist-info/RECORD

@@ -1,15 +1,15 @@
  kiln_ai/__init__.py,sha256=Sc4z8LRVFMwJUoc_DPVUriSXTZ6PO9MaJ80PhRbKyB8,34
  kiln_ai/adapters/__init__.py,sha256=3NC1lE_Sg1bF4IsKCoUgje2GL0IwTd1atw1BcDLI8IA,883
  kiln_ai/adapters/base_adapter.py,sha256=xXCISAJHaPCYHad28CS0wZEUlx711FZ_6AwW4rJx4jk,6688
- kiln_ai/adapters/langchain_adapters.py,sha256=Fo7w7hWdkxuuvxoNZhcGE25tOS6ObzhKEUKGszzPdtk,4929
- kiln_ai/adapters/ml_model_list.py,sha256=J13pDFp6UwTd7sa_kZyesRaqR1rin4U1iYMc5NYF05Q,17507
+ kiln_ai/adapters/langchain_adapters.py,sha256=WNxhuTdjGCsCyqmXJNLe7HJ-MzJ08yagGV-eAHPZF-E,5411
+ kiln_ai/adapters/ml_model_list.py,sha256=mqtFCav-m4m4DzxtoWHBktxrTB7haj0vUPx3xa8CL2U,23414
  kiln_ai/adapters/prompt_builders.py,sha256=nfZnEr1E30ZweQhEzIP21rNrL2Or1ILajyX8gU3B7w0,7796
  kiln_ai/adapters/test_langchain_adapter.py,sha256=_xHpVAkkoGh0PRO3BFFqvVj95SVtYZPOdFbYGYfzvQ0,1876
- kiln_ai/adapters/test_ml_model_list.py,sha256=BNuJSIegMMLzcICDR49qLFm7ezSl188LE4-W98c73tA,2786
- kiln_ai/adapters/test_prompt_adaptors.py,sha256=TXfSLfOHcg9EJINLfyJDQ-WcMw4He8ab4k-fGeryJcY,6033
+ kiln_ai/adapters/test_ml_model_list.py,sha256=XHbwEFFb7WvZ6UkArqIiQ_yhS_urezHtgvJOSnaricY,4660
+ kiln_ai/adapters/test_prompt_adaptors.py,sha256=z_X-REnWmTai23Ay_xSEB1qRKpp3HAeqgs4K6TQneb0,6332
  kiln_ai/adapters/test_prompt_builders.py,sha256=WmTR59tnKnKQ5gnX1X9EqvEUdQr0PQ8OvadYtRQR5sQ,11483
  kiln_ai/adapters/test_saving_adapter_results.py,sha256=tQvpLawo8mR2scPwmRCIz9Sp0ZkerS3kVJKBzlcjwRE,6041
- kiln_ai/adapters/test_structured_output.py,sha256=Okl6kLaAEKOuy1UBvQuiM5LGmJJi2aPB8sQR4bzIyIA,8755
+ kiln_ai/adapters/test_structured_output.py,sha256=Z9A2R-TC-2atsdr8sGVGDlJhfa7uytW8Xi8PKBdEEAw,9033
  kiln_ai/adapters/repair/__init__.py,sha256=dOO9MEpEhjiwzDVFg3MNfA2bKMPlax9iekDatpTkX8E,217
  kiln_ai/adapters/repair/repair_task.py,sha256=VXvX1l9AYDE_GV0i3S_vPThltJoQlCFVCCHV9m-QA7k,3297
  kiln_ai/adapters/repair/test_repair_task.py,sha256=12PHb4SgBvVdLUzjZz31M8OTa8D8QjHD0Du4s7ij-i8,7819
@@ -27,7 +27,7 @@ kiln_ai/utils/__init__.py,sha256=PTD0MwBCKAMIOGsTAwsFaJOusTJJoRFTfOGqRvCaU-E,142
  kiln_ai/utils/config.py,sha256=jXUB8lwFkxLNEaizwIsoeFLg1BwjWr39-5KdEGF37Bg,5424
  kiln_ai/utils/formatting.py,sha256=VtB9oag0lOGv17dwT7OPX_3HzBfaU9GsLH-iLete0yM,97
  kiln_ai/utils/test_config.py,sha256=lbN0NhgKPEZ0idaS-zTn6mWsSAV6omo32JcIy05h2-M,7411
- kiln_ai-0.5.2.dist-info/METADATA,sha256=yClipqFUcwHfRWVBH6oJlE2D0S2ENBk58gFTapZBCjk,1718
- kiln_ai-0.5.2.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
- kiln_ai-0.5.2.dist-info/licenses/LICENSE.txt,sha256=-AhuIX-CMdNGJNj74C29e9cKKmsh-1PBPINCsNvwAeg,82
- kiln_ai-0.5.2.dist-info/RECORD,,
+ kiln_ai-0.5.4.dist-info/METADATA,sha256=UGIFG4gewCNOmUQpxYq_3g_dzCfq3tLuLLTVCnrhbOY,2017
+ kiln_ai-0.5.4.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
+ kiln_ai-0.5.4.dist-info/licenses/LICENSE.txt,sha256=_NA5pnTYgRRr4qH6lE3X-TuZJ8iRcMUi5ASoGr-lEx8,1209
+ kiln_ai-0.5.4.dist-info/RECORD,,

kiln_ai-0.5.4.dist-info/licenses/LICENSE.txt

@@ -0,0 +1,13 @@
+
+
+ This license applies only to the software in the libs/core directory.
+
+ =======================================================
+
+ Copyright 2024 - Chesterfield Laboratories Inc.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

kiln_ai-0.5.2.dist-info/licenses/LICENSE.txt

@@ -1,4 +0,0 @@
-
- All rights reserved.
-
- Copyright (c) Steve Cosman, Chesterfield Laboratories Inc.