kiln-ai 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of kiln-ai might be problematic.

kiln_ai/adapters/langchain_adapters.py

@@ -1,16 +1,23 @@
 from typing import Dict
 
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import HumanMessage, SystemMessage
 from langchain_core.messages.base import BaseMessage
+from langchain_core.runnables import Runnable
+from pydantic import BaseModel
 
 import kiln_ai.datamodel as datamodel
 
 from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder
 from .ml_model_list import langchain_model_from
 
+LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]
+
 
 class LangChainPromptAdapter(BaseAdapter):
+    _model: LangChainModelType | None = None
+
     def __init__(
         self,
         kiln_task: datamodel.Task,
@@ -21,7 +28,7 @@ class LangChainPromptAdapter(BaseAdapter):
     ):
         super().__init__(kiln_task, prompt_builder=prompt_builder)
         if custom_model is not None:
-            self.model = custom_model
+            self._model = custom_model
 
             # Attempt to infer model provider and name from custom model
             self.model_provider = "custom.langchain:" + custom_model.__class__.__name__
@@ -37,19 +44,32 @@ class LangChainPromptAdapter(BaseAdapter):
             ):
                 self.model_name = "custom.langchain:" + getattr(custom_model, "model")
         elif model_name is not None:
-            self.model = langchain_model_from(model_name, provider)
            self.model_name = model_name
            self.model_provider = provider or "custom.langchain.default_provider"
        else:
            raise ValueError(
                "model_name and provider must be provided if custom_model is not provided"
            )
+
+    def adapter_specific_instructions(self) -> str | None:
+        # TODO: would be better to explicitly use bind_tools:tool_choice="task_response" here
         if self.has_structured_output():
-            if not hasattr(self.model, "with_structured_output") or not callable(
-                getattr(self.model, "with_structured_output")
+            return "Always respond with a tool call. Never respond with a human readable message."
+        return None
+
+    async def model(self) -> LangChainModelType:
+        # cached model
+        if self._model:
+            return self._model
+
+        self._model = await langchain_model_from(self.model_name, self.model_provider)
+
+        if self.has_structured_output():
+            if not hasattr(self._model, "with_structured_output") or not callable(
+                getattr(self._model, "with_structured_output")
             ):
                 raise ValueError(
-                    f"model {self.model} does not support structured output, cannot use output_json_schema"
+                    f"model {self._model} does not support structured output, cannot use output_json_schema"
                 )
             # Langchain expects title/description to be at top level, on top of json schema
             output_schema = self.kiln_task.output_schema()
@@ -59,15 +79,10 @@ class LangChainPromptAdapter(BaseAdapter):
                 )
             output_schema["title"] = "task_response"
             output_schema["description"] = "A response from the task"
-            self.model = self.model.with_structured_output(
+            self._model = self._model.with_structured_output(
                 output_schema, include_raw=True
             )
-
-    def adapter_specific_instructions(self) -> str | None:
-        # TODO: would be better to explicitly use bind_tools:tool_choice="task_response" here
-        if self.has_structured_output():
-            return "Always respond with a tool call. Never respond with a human readable message."
-        return None
+        return self._model
 
     async def _run(self, input: Dict | str) -> Dict | str:
         prompt = self.build_prompt()
@@ -76,7 +91,8 @@ class LangChainPromptAdapter(BaseAdapter):
             SystemMessage(content=prompt),
             HumanMessage(content=user_msg),
         ]
-        response = self.model.invoke(messages)
+        model = await self.model()
+        response = model.invoke(messages)
 
         if self.has_structured_output():
             if (

kiln_ai/adapters/ml_model_list.py

@@ -2,9 +2,10 @@ import os
 from dataclasses import dataclass
 from enum import Enum
 from os import getenv
-from typing import Dict, List, NoReturn
+from typing import Any, Dict, List, NoReturn
 
 import httpx
+import requests
 from langchain_aws import ChatBedrockConverse
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_groq import ChatGroq
@@ -43,6 +44,8 @@ class ModelFamily(str, Enum):
     phi = "phi"
     mistral = "mistral"
     gemma = "gemma"
+    gemini = "gemini"
+    claude = "claude"
 
 
 # Where models have instruct and raw versions, instruct is default and raw is specified
@@ -55,6 +58,9 @@ class ModelName(str, Enum):
     llama_3_1_8b = "llama_3_1_8b"
     llama_3_1_70b = "llama_3_1_70b"
     llama_3_1_405b = "llama_3_1_405b"
+    llama_3_2_3b = "llama_3_2_3b"
+    llama_3_2_11b = "llama_3_2_11b"
+    llama_3_2_90b = "llama_3_2_90b"
     gpt_4o_mini = "gpt_4o_mini"
     gpt_4o = "gpt_4o"
     phi_3_5 = "phi_3_5"
@@ -63,6 +69,12 @@ class ModelName(str, Enum):
     gemma_2_2b = "gemma_2_2b"
     gemma_2_9b = "gemma_2_9b"
     gemma_2_27b = "gemma_2_27b"
+    claude_3_5_haiku = "claude_3_5_haiku"
+    claude_3_5_sonnet = "claude_3_5_sonnet"
+    gemini_1_5_flash = "gemini_1_5_flash"
+    gemini_1_5_flash_8b = "gemini_1_5_flash_8b"
+    gemini_1_5_pro = "gemini_1_5_pro"
+    nemotron_70b = "nemotron_70b"
 
 
 class KilnModelProvider(BaseModel):
@@ -132,6 +144,79 @@ built_in_models: List[KilnModel] = [
             ),
         ],
     ),
+    # Claude 3.5 Haiku
+    KilnModel(
+        family=ModelFamily.claude,
+        name=ModelName.claude_3_5_haiku,
+        friendly_name="Claude 3.5 Haiku",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "anthropic/claude-3-5-haiku"},
+            ),
+        ],
+    ),
+    # Claude 3.5 Sonnet
+    KilnModel(
+        family=ModelFamily.claude,
+        name=ModelName.claude_3_5_sonnet,
+        friendly_name="Claude 3.5 Sonnet",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "anthropic/claude-3.5-sonnet"},
+            ),
+        ],
+    ),
+    # Gemini 1.5 Pro
+    KilnModel(
+        family=ModelFamily.gemini,
+        name=ModelName.gemini_1_5_pro,
+        friendly_name="Gemini 1.5 Pro",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "google/gemini-pro-1.5"},
+            ),
+        ],
+    ),
+    # Gemini 1.5 Flash
+    KilnModel(
+        family=ModelFamily.gemini,
+        name=ModelName.gemini_1_5_flash,
+        friendly_name="Gemini 1.5 Flash",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "google/gemini-flash-1.5"},
+            ),
+        ],
+    ),
+    # Gemini 1.5 Flash 8B
+    KilnModel(
+        family=ModelFamily.gemini,
+        name=ModelName.gemini_1_5_flash_8b,
+        friendly_name="Gemini 1.5 Flash 8B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "google/gemini-flash-1.5-8b"},
+            ),
+        ],
+    ),
+    # Nemotron 70B
+    KilnModel(
+        family=ModelFamily.llama,
+        name=ModelName.nemotron_70b,
+        friendly_name="Nemotron 70B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                supports_structured_output=False,
+                provider_options={"model": "nvidia/llama-3.1-nemotron-70b-instruct"},
+            ),
+        ],
+    ),
     # Llama 3.1-8b
     KilnModel(
         family=ModelFamily.llama,
@@ -144,6 +229,7 @@ built_in_models: List[KilnModel] = [
             ),
             KilnModelProvider(
                 name=ModelProviderName.amazon_bedrock,
+                supports_structured_output=False,
                 provider_options={
                     "model": "meta.llama3-1-8b-instruct-v1:0",
                     "region_name": "us-west-2",  # Llama 3.1 only in west-2
@@ -151,10 +237,14 @@ built_in_models: List[KilnModel] = [
             ),
             KilnModelProvider(
                 name=ModelProviderName.ollama,
-                provider_options={"model": "llama3.1"},  # 8b is default
+                provider_options={
+                    "model": "llama3.1:8b",
+                    "model_aliases": ["llama3.1"],  # 8b is default
+                },
             ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
+                supports_structured_output=False,
                 provider_options={"model": "meta-llama/llama-3.1-8b-instruct"},
             ),
         ],
@@ -171,7 +261,6 @@ built_in_models: List[KilnModel] = [
             ),
             KilnModelProvider(
                 name=ModelProviderName.amazon_bedrock,
-                # TODO: this should work but a bug in the bedrock response schema
                 supports_structured_output=False,
                 provider_options={
                     "model": "meta.llama3-1-70b-instruct-v1:0",
@@ -182,11 +271,10 @@ built_in_models: List[KilnModel] = [
                 name=ModelProviderName.openrouter,
                 provider_options={"model": "meta-llama/llama-3.1-70b-instruct"},
             ),
-            # TODO: enable once tests update to check if model is available
-            # KilnModelProvider(
-            #     provider=ModelProviders.ollama,
-            #     provider_options={"model": "llama3.1:70b"},
-            # ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "llama3.1:70b"},
+            ),
         ],
     ),
     # Llama 3.1 405b
@@ -195,11 +283,6 @@ built_in_models: List[KilnModel] = [
         name=ModelName.llama_3_1_405b,
         friendly_name="Llama 3.1 405B",
         providers=[
-            # TODO: bring back when groq does: https://console.groq.com/docs/models
-            # KilnModelProvider(
-            #     name=ModelProviderName.groq,
-            #     provider_options={"model": "llama-3.1-405b-instruct-v1:0"},
-            # ),
             KilnModelProvider(
                 name=ModelProviderName.amazon_bedrock,
                 provider_options={
@@ -207,11 +290,10 @@ built_in_models: List[KilnModel] = [
                     "region_name": "us-west-2",  # Llama 3.1 only in west-2
                 },
             ),
-            # TODO: enable once tests update to check if model is available
-            # KilnModelProvider(
-            #     name=ModelProviderName.ollama,
-            #     provider_options={"model": "llama3.1:405b"},
-            # ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "llama3.1:405b"},
+            ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 provider_options={"model": "meta-llama/llama-3.1-405b-instruct"},
@@ -247,11 +329,49 @@ built_in_models: List[KilnModel] = [
                 name=ModelProviderName.openrouter,
                 provider_options={"model": "mistralai/mistral-large"},
             ),
-            # TODO: enable once tests update to check if model is available
-            # KilnModelProvider(
-            #     provider=ModelProviders.ollama,
-            #     provider_options={"model": "mistral-large"},
-            # ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "mistral-large"},
+            ),
+        ],
+    ),
+    # Llama 3.2 3B
+    KilnModel(
+        family=ModelFamily.llama,
+        name=ModelName.llama_3_2_3b,
+        friendly_name="Llama 3.2 3B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                supports_structured_output=False,
+                provider_options={"model": "meta-llama/llama-3.2-3b-instruct"},
+            ),
+        ],
+    ),
+    # Llama 3.2 11B
+    KilnModel(
+        family=ModelFamily.llama,
+        name=ModelName.llama_3_2_11b,
+        friendly_name="Llama 3.2 11B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                supports_structured_output=False,
+                provider_options={"model": "meta-llama/llama-3.2-11b-vision-instruct"},
+            ),
+        ],
+    ),
+    # Llama 3.2 90B
+    KilnModel(
+        family=ModelFamily.llama,
+        name=ModelName.llama_3_2_90b,
+        friendly_name="Llama 3.2 90B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                supports_structured_output=False,
+                provider_options={"model": "meta-llama/llama-3.2-90b-vision-instruct"},
+            ),
         ],
     ),
     # Phi 3.5
@@ -263,6 +383,7 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.ollama,
+                supports_structured_output=False,
                 provider_options={"model": "phi3.5"},
             ),
             KilnModelProvider(
@@ -280,6 +401,7 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.ollama,
+                supports_structured_output=False,
                 provider_options={
                     "model": "gemma2:2b",
                 },
@@ -293,13 +415,12 @@ built_in_models: List[KilnModel] = [
         friendly_name="Gemma 2 9B",
         supports_structured_output=False,
         providers=[
-            # TODO: enable once tests update to check if model is available
-            # KilnModelProvider(
-            #     name=ModelProviderName.ollama,
-            #     provider_options={
-            #         "model": "gemma2:9b",
-            #     },
-            # ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={
+                    "model": "gemma2:9b",
+                },
+            ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 provider_options={"model": "google/gemma-2-9b-it"},
@@ -313,13 +434,12 @@ built_in_models: List[KilnModel] = [
         friendly_name="Gemma 2 27B",
         supports_structured_output=False,
         providers=[
-            # TODO: enable once tests update to check if model is available
-            # KilnModelProvider(
-            #     name=ModelProviderName.ollama,
-            #     provider_options={
-            #         "model": "gemma2:27b",
-            #     },
-            # ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={
+                    "model": "gemma2:27b",
+                },
+            ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 provider_options={"model": "google/gemma-2-27b-it"},
@@ -417,7 +537,9 @@ def check_provider_warnings(provider_name: ModelProviderName):
         raise ValueError(warning_check.message)
 
 
-def langchain_model_from(name: str, provider_name: str | None = None) -> BaseChatModel:
+async def langchain_model_from(
+    name: str, provider_name: str | None = None
+) -> BaseChatModel:
     """
     Creates a LangChain chat model instance for the specified model and provider.
 
@@ -476,7 +598,23 @@ def langchain_model_from(name: str, provider_name: str | None = None) -> BaseCha
             **provider.provider_options,
         )
     elif provider.name == ModelProviderName.ollama:
-        return ChatOllama(**provider.provider_options, base_url=ollama_base_url())
+        # Ollama model naming is pretty flexible. We try a few versions of the model name
+        potential_model_names = []
+        if "model" in provider.provider_options:
+            potential_model_names.append(provider.provider_options["model"])
+        if "model_aliases" in provider.provider_options:
+            potential_model_names.extend(provider.provider_options["model_aliases"])
+
+        # Get the list of models Ollama supports
+        ollama_connection = await get_ollama_connection()
+        if ollama_connection is None:
+            raise ValueError("Failed to connect to Ollama. Ensure Ollama is running.")
+
+        for model_name in potential_model_names:
+            if ollama_model_supported(ollama_connection, model_name):
+                return ChatOllama(model=model_name, base_url=ollama_base_url())
+
+        raise ValueError(f"Model {name} not installed on Ollama")
     elif provider.name == ModelProviderName.openrouter:
         api_key = Config.shared().open_router_api_key
         base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
@@ -519,3 +657,67 @@ async def ollama_online() -> bool:
     except httpx.RequestError:
         return False
     return True
+
+
+class OllamaConnection(BaseModel):
+    message: str
+    models: List[str]
+
+
+# Parse the Ollama /api/tags response
+def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
+    # Build a list of models we support for Ollama from the built-in model list
+    supported_ollama_models = [
+        provider.provider_options["model"]
+        for model in built_in_models
+        for provider in model.providers
+        if provider.name == ModelProviderName.ollama
+    ]
+    # Append model_aliases to supported_ollama_models
+    supported_ollama_models.extend(
+        [
+            alias
+            for model in built_in_models
+            for provider in model.providers
+            for alias in provider.provider_options.get("model_aliases", [])
+        ]
+    )
+
+    if "models" in tags:
+        models = tags["models"]
+        if isinstance(models, list):
+            model_names = [model["model"] for model in models]
+            print(f"model_names: {model_names}")
+            available_supported_models = [
+                model
+                for model in model_names
+                if model in supported_ollama_models
+                or model in [f"{m}:latest" for m in supported_ollama_models]
+            ]
+            if available_supported_models:
+                return OllamaConnection(
+                    message="Ollama connected",
+                    models=available_supported_models,
+                )
+
+    return OllamaConnection(
+        message="Ollama is running, but no supported models are installed. Install one or more supported model, like 'ollama pull phi3.5'.",
+        models=[],
+    )
+
+
+async def get_ollama_connection() -> OllamaConnection | None:
+    """
+    Gets the connection status for Ollama.
+    """
+    try:
+        tags = requests.get(ollama_base_url() + "/api/tags", timeout=5).json()
+
+    except Exception:
+        return None
+
+    return parse_ollama_tags(tags)
+
+
+def ollama_model_supported(conn: OllamaConnection, model_name: str) -> bool:
+    return model_name in conn.models or f"{model_name}:latest" in conn.models

kiln_ai/adapters/test_ml_model_list.py

@@ -1,10 +1,14 @@
+import json
 from unittest.mock import patch
 
 import pytest
 
 from kiln_ai.adapters.ml_model_list import (
     ModelProviderName,
+    OllamaConnection,
     check_provider_warnings,
+    ollama_model_supported,
+    parse_ollama_tags,
     provider_name_from_id,
     provider_warnings,
 )
@@ -97,3 +101,25 @@ def test_provider_name_from_id_case_sensitivity():
 )
 def test_provider_name_from_id_parametrized(provider_id, expected_name):
     assert provider_name_from_id(provider_id) == expected_name
+
+
+def test_parse_ollama_tags_no_models():
+    json_response = '{"models":[{"name":"phi3.5:latest","model":"phi3.5:latest","modified_at":"2024-10-02T12:04:35.191519822-04:00","size":2176178843,"digest":"61819fb370a3c1a9be6694869331e5f85f867a079e9271d66cb223acb81d04ba","details":{"parent_model":"","format":"gguf","family":"phi3","families":["phi3"],"parameter_size":"3.8B","quantization_level":"Q4_0"}},{"name":"gemma2:2b","model":"gemma2:2b","modified_at":"2024-09-09T16:46:38.64348929-04:00","size":1629518495,"digest":"8ccf136fdd5298f3ffe2d69862750ea7fb56555fa4d5b18c04e3fa4d82ee09d7","details":{"parent_model":"","format":"gguf","family":"gemma2","families":["gemma2"],"parameter_size":"2.6B","quantization_level":"Q4_0"}},{"name":"llama3.1:latest","model":"llama3.1:latest","modified_at":"2024-09-01T17:19:43.481523695-04:00","size":4661230720,"digest":"f66fc8dc39ea206e03ff6764fcc696b1b4dfb693f0b6ef751731dd4e6269046e","details":{"parent_model":"","format":"gguf","family":"llama","families":["llama"],"parameter_size":"8.0B","quantization_level":"Q4_0"}}]}'
+    tags = json.loads(json_response)
+    print(json.dumps(tags, indent=2))
+    conn = parse_ollama_tags(tags)
+    assert "phi3.5:latest" in conn.models
+    assert "gemma2:2b" in conn.models
+    assert "llama3.1:latest" in conn.models
+
+
+def test_ollama_model_supported():
+    conn = OllamaConnection(
+        models=["phi3.5:latest", "gemma2:2b", "llama3.1:latest"], message="Connected"
+    )
+    assert ollama_model_supported(conn, "phi3.5:latest")
+    assert ollama_model_supported(conn, "phi3.5")
+    assert ollama_model_supported(conn, "gemma2:2b")
+    assert ollama_model_supported(conn, "llama3.1:latest")
+    assert ollama_model_supported(conn, "llama3.1")
+    assert not ollama_model_supported(conn, "unknown_model")

kiln_ai/adapters/test_prompt_adaptors.py

@@ -16,9 +16,25 @@ async def test_groq(tmp_path):
     await run_simple_test(tmp_path, "llama_3_1_8b", "groq")
 
 
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "llama_3_1_8b",
+        "llama_3_1_70b",
+        "gemini_1_5_pro",
+        "gemini_1_5_flash",
+        "gemini_1_5_flash_8b",
+        "nemotron_70b",
+        "llama_3_2_3b",
+        "llama_3_2_11b",
+        "llama_3_2_90b",
+        "claude_3_5_haiku",
+        "claude_3_5_sonnet",
+    ],
+)
 @pytest.mark.paid
-async def test_openrouter(tmp_path):
-    await run_simple_test(tmp_path, "llama_3_1_8b", "openrouter")
+async def test_openrouter(tmp_path, model_name):
+    await run_simple_test(tmp_path, model_name, "openrouter")
 
 
 @pytest.mark.ollama

kiln_ai/adapters/test_structured_output.py

@@ -15,19 +15,21 @@ from kiln_ai.adapters.ml_model_list import (
 from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema
 
 
+@pytest.mark.parametrize(
+    "model_name,provider",
+    [
+        ("llama_3_1_8b", "groq"),
+        ("mistral_nemo", "openrouter"),
+        ("llama_3_1_70b", "amazon_bedrock"),
+        ("claude_3_5_sonnet", "openrouter"),
+        ("gemini_1_5_pro", "openrouter"),
+        ("gemini_1_5_flash", "openrouter"),
+        ("gemini_1_5_flash_8b", "openrouter"),
+    ],
+)
 @pytest.mark.paid
-async def test_structured_output_groq(tmp_path):
-    await run_structured_output_test(tmp_path, "llama_3_1_8b", "groq")
-
-
-@pytest.mark.paid
-async def test_structured_output_openrouter(tmp_path):
-    await run_structured_output_test(tmp_path, "mistral_nemo", "openrouter")
-
-
-@pytest.mark.paid
-async def test_structured_output_bedrock(tmp_path):
-    await run_structured_output_test(tmp_path, "llama_3_1_70b", "amazon_bedrock")
+async def test_structured_output(tmp_path, model_name, provider):
+    await run_structured_output_test(tmp_path, model_name, provider)
 
 
 @pytest.mark.ollama
@@ -39,16 +41,17 @@ async def test_structured_output_ollama_phi(tmp_path):
     await run_structured_output_test(tmp_path, "phi_3_5", "ollama")
 
 
-@pytest.mark.ollama
+@pytest.mark.paid
 async def test_structured_output_gpt_4o_mini(tmp_path):
     await run_structured_output_test(tmp_path, "gpt_4o_mini", "openai")
 
 
+@pytest.mark.parametrize("model_name", ["llama_3_1_8b"])
 @pytest.mark.ollama
-async def test_structured_output_ollama_llama(tmp_path):
+async def test_structured_output_ollama_llama(tmp_path, model_name):
     if not await ollama_online():
         pytest.skip("Ollama API not running. Expect it running on localhost:11434")
-    await run_structured_output_test(tmp_path, "llama_3_1_8b", "ollama")
+    await run_structured_output_test(tmp_path, model_name, "ollama")
 
 
 class MockAdapter(BaseAdapter):
@@ -105,6 +108,7 @@ async def test_mock_unstructred_response(tmp_path):
 @pytest.mark.paid
 @pytest.mark.ollama
 async def test_all_built_in_models_structured_output(tmp_path):
+    errors = []
     for model in built_in_models:
         if not model.supports_structured_output:
             print(
@@ -121,7 +125,10 @@ async def test_all_built_in_models_structured_output(tmp_path):
                 print(f"Running {model.name} {provider.name}")
                 await run_structured_output_test(tmp_path, model.name, provider.name)
             except Exception as e:
-                raise RuntimeError(f"Error running {model.name} {provider}") from e
+                print(f"Error running {model.name} {provider.name}")
+                errors.append(f"{model.name} {provider.name}: {e}")
+    if len(errors) > 0:
+        raise RuntimeError(f"Errors: {errors}")
 
 
 def build_structured_output_test_task(tmp_path: Path):

kiln_ai-0.5.3.dist-info/METADATA → kiln_ai-0.5.5.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: kiln-ai
-Version: 0.5.3
+Version: 0.5.5
 Summary: Kiln AI
 Project-URL: Homepage, https://getkiln.ai
 Project-URL: Repository, https://github.com/Kiln-AI/kiln
@@ -9,9 +9,11 @@ Project-URL: Issues, https://github.com/Kiln-AI/kiln/issues
 Author-email: "Steve Cosman, Chesterfield Laboratories Inc" <scosman@users.noreply.github.com>
 License-File: LICENSE.txt
 Classifier: Intended Audience :: Developers
+Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
 Requires-Python: >=3.10
 Requires-Dist: coverage>=7.6.4
 Requires-Dist: jsonschema>=4.23.0
@@ -29,6 +31,12 @@ Description-Content-Type: text/markdown
 
 # kiln_ai
 
+<p align="center">
+    <picture>
+        <img width="205" alt="Kiln AI Logo" src="https://github.com/user-attachments/assets/5fbcbdf7-1feb-45c9-bd73-99a46dd0a47f">
+    </picture>
+</p>
+
 [![PyPI - Version](https://img.shields.io/pypi/v/kiln-ai.svg?logo=pypi&label=PyPI&logoColor=gold)](https://pypi.org/project/kiln-ai)
 [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/kiln-ai.svg)](https://pypi.org/project/kiln-ai)
 [![Docs](https://img.shields.io/badge/docs-pdoc-blue)](https://kiln-ai.github.io/Kiln/kiln_core_docs/index.html)
@@ -41,12 +49,41 @@ Description-Content-Type: text/markdown
 pip install kiln_ai
 ```
 
-## About Kiln AI
+## About
+
+This package is the Kiln AI core library. There is also a separate desktop application and server package. Learn more about Kiln AI at [getkiln.ai](https://getkiln.ai)
+
+- Github: [github.com/Kiln-AI/kiln](https://github.com/Kiln-AI/kiln)
+- Core Library Docs: [https://kiln-ai.github.io/Kiln/kiln_core_docs/index.html](https://kiln-ai.github.io/Kiln/kiln_core_docs/index.html)
+
+## Quick Start
 
-Learn more about Kiln AI at [getkiln.ai](https://getkiln.ai)
+```python
+from kiln_ai.datamodel import Project
 
-This package is the Kiln AI core library. There is also a separate desktop application and server package.
+print("Reading Kiln project")
+project = Project.load_from_file("path/to/project.kiln")
+print("Project: ", project.name, " - ", project.description)
 
-Github: [github.com/Kiln-AI/kiln](https://github.com/Kiln-AI/kiln)
+task = project.tasks()[0]
+print("Task: ", task.name, " - ", task.description)
+print("Total dataset size:", len(task.runs()))
 
-Docs: [https://kiln-ai.github.io/Kiln/kiln_core_docs/index.html](https://kiln-ai.github.io/Kiln/kiln_core_docs/index.html)
+# ... app specific code using the typed kiln datamodel
+
+# Alternatively, load data into pandas or a similar tool:
+import glob
+import json
+import pandas as pd
+from pathlib import Path
+
+dataitem_glob = str(task.path.parent) + "/runs/*/task_run.kiln"
+
+dfs = []
+for file in glob.glob(dataitem_glob):
+    js = json.loads(Path(file).read_text())
+    df = pd.json_normalize(js)
+    dfs.append(df)
+final_df = pd.concat(dfs, ignore_index=True)
+print(final_df)
+```
@@ -1,15 +1,15 @@
1
1
  kiln_ai/__init__.py,sha256=Sc4z8LRVFMwJUoc_DPVUriSXTZ6PO9MaJ80PhRbKyB8,34
2
2
  kiln_ai/adapters/__init__.py,sha256=3NC1lE_Sg1bF4IsKCoUgje2GL0IwTd1atw1BcDLI8IA,883
3
3
  kiln_ai/adapters/base_adapter.py,sha256=xXCISAJHaPCYHad28CS0wZEUlx711FZ_6AwW4rJx4jk,6688
4
- kiln_ai/adapters/langchain_adapters.py,sha256=Fo7w7hWdkxuuvxoNZhcGE25tOS6ObzhKEUKGszzPdtk,4929
5
- kiln_ai/adapters/ml_model_list.py,sha256=J13pDFp6UwTd7sa_kZyesRaqR1rin4U1iYMc5NYF05Q,17507
4
+ kiln_ai/adapters/langchain_adapters.py,sha256=WNxhuTdjGCsCyqmXJNLe7HJ-MzJ08yagGV-eAHPZF-E,5411
5
+ kiln_ai/adapters/ml_model_list.py,sha256=ueh2jUqCmgGg-jMv0exn5siOU_6p0rGeJs3jy8ZWvuE,23821
6
6
  kiln_ai/adapters/prompt_builders.py,sha256=nfZnEr1E30ZweQhEzIP21rNrL2Or1ILajyX8gU3B7w0,7796
7
7
  kiln_ai/adapters/test_langchain_adapter.py,sha256=_xHpVAkkoGh0PRO3BFFqvVj95SVtYZPOdFbYGYfzvQ0,1876
8
- kiln_ai/adapters/test_ml_model_list.py,sha256=BNuJSIegMMLzcICDR49qLFm7ezSl188LE4-W98c73tA,2786
9
- kiln_ai/adapters/test_prompt_adaptors.py,sha256=TXfSLfOHcg9EJINLfyJDQ-WcMw4He8ab4k-fGeryJcY,6033
8
+ kiln_ai/adapters/test_ml_model_list.py,sha256=XHbwEFFb7WvZ6UkArqIiQ_yhS_urezHtgvJOSnaricY,4660
9
+ kiln_ai/adapters/test_prompt_adaptors.py,sha256=W3TeacWs5iPA3BE1OJ6VkIftrHWzXd3edBoUgFaQAek,6389
10
10
  kiln_ai/adapters/test_prompt_builders.py,sha256=WmTR59tnKnKQ5gnX1X9EqvEUdQr0PQ8OvadYtRQR5sQ,11483
11
11
  kiln_ai/adapters/test_saving_adapter_results.py,sha256=tQvpLawo8mR2scPwmRCIz9Sp0ZkerS3kVJKBzlcjwRE,6041
12
- kiln_ai/adapters/test_structured_output.py,sha256=Okl6kLaAEKOuy1UBvQuiM5LGmJJi2aPB8sQR4bzIyIA,8755
12
+ kiln_ai/adapters/test_structured_output.py,sha256=Z9A2R-TC-2atsdr8sGVGDlJhfa7uytW8Xi8PKBdEEAw,9033
13
13
  kiln_ai/adapters/repair/__init__.py,sha256=dOO9MEpEhjiwzDVFg3MNfA2bKMPlax9iekDatpTkX8E,217
14
14
  kiln_ai/adapters/repair/repair_task.py,sha256=VXvX1l9AYDE_GV0i3S_vPThltJoQlCFVCCHV9m-QA7k,3297
15
15
  kiln_ai/adapters/repair/test_repair_task.py,sha256=12PHb4SgBvVdLUzjZz31M8OTa8D8QjHD0Du4s7ij-i8,7819
@@ -27,7 +27,7 @@ kiln_ai/utils/__init__.py,sha256=PTD0MwBCKAMIOGsTAwsFaJOusTJJoRFTfOGqRvCaU-E,142
27
27
  kiln_ai/utils/config.py,sha256=jXUB8lwFkxLNEaizwIsoeFLg1BwjWr39-5KdEGF37Bg,5424
28
28
  kiln_ai/utils/formatting.py,sha256=VtB9oag0lOGv17dwT7OPX_3HzBfaU9GsLH-iLete0yM,97
29
29
  kiln_ai/utils/test_config.py,sha256=lbN0NhgKPEZ0idaS-zTn6mWsSAV6omo32JcIy05h2-M,7411
30
- kiln_ai-0.5.3.dist-info/METADATA,sha256=LlMBr0-VSmD3DyCwuCB8Y2ao248HMS-6Gs-3Epfwui0,1915
31
- kiln_ai-0.5.3.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
32
- kiln_ai-0.5.3.dist-info/licenses/LICENSE.txt,sha256=-AhuIX-CMdNGJNj74C29e9cKKmsh-1PBPINCsNvwAeg,82
33
- kiln_ai-0.5.3.dist-info/RECORD,,
30
+ kiln_ai-0.5.5.dist-info/METADATA,sha256=rD2UKYBIVHUrfsPP7-BhaUXGdLXVkcJDIUs8i75GSX8,3005
31
+ kiln_ai-0.5.5.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
32
+ kiln_ai-0.5.5.dist-info/licenses/LICENSE.txt,sha256=_NA5pnTYgRRr4qH6lE3X-TuZJ8iRcMUi5ASoGr-lEx8,1209
33
+ kiln_ai-0.5.5.dist-info/RECORD,,

kiln_ai-0.5.5.dist-info/licenses/LICENSE.txt (added)

@@ -0,0 +1,13 @@
+
+
+This license applies only to the software in the libs/core directory.
+
+=======================================================
+
+Copyright 2024 - Chesterfield Laboratories Inc.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
@@ -1,4 +0,0 @@
1
-
2
- All rights reserved.
3
-
4
- Copyright (c) Steve Cosman, Chesterfield Laboratories Inc.