kiln-ai 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff shows the content changes between publicly available package versions released to one of the supported registries, as they appear in those registries. It is provided for informational purposes only.

Potentially problematic release: this version of kiln-ai has been flagged as possibly problematic.

@@ -9,6 +9,7 @@ def adapter_for_task(
     model_name: str | None = None,
     provider: str | None = None,
     prompt_builder: BasePromptBuilder | None = None,
+    tags: list[str] | None = None,
 ) -> BaseAdapter:
     # We use langchain for everything right now, but can add any others here
     return LangchainAdapter(
@@ -16,4 +17,5 @@ def adapter_for_task(
         model_name=model_name,
         provider=provider,
         prompt_builder=prompt_builder,
+        tags=tags,
     )
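The new tags parameter flows from adapter_for_task through LangchainAdapter into BaseAdapter, which stores it as default_tags and stamps it onto every run it records (see the BaseAdapter hunks below). A minimal usage sketch; adapter_for_task's module path and its first positional argument are assumptions not shown in this diff:

    # Sketch only: the import path and first argument are assumed, not confirmed here.
    from kiln_ai.adapters.adapter_registry import adapter_for_task

    adapter = adapter_for_task(
        task,                        # an existing kiln_ai.datamodel.Task
        model_name="llama_3_3_70b",
        provider="groq",
        tags=["nightly-eval"],
    )
    # Runs produced by this adapter are recorded with tags=["nightly-eval"]
    # via BaseAdapter.default_tags (tags=self.default_tags or []).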
@@ -45,12 +45,16 @@ class BaseAdapter(metaclass=ABCMeta):
     """

     def __init__(
-        self, kiln_task: Task, prompt_builder: BasePromptBuilder | None = None
+        self,
+        kiln_task: Task,
+        prompt_builder: BasePromptBuilder | None = None,
+        tags: list[str] | None = None,
     ):
         self.prompt_builder = prompt_builder or SimplePromptBuilder(kiln_task)
         self.kiln_task = kiln_task
         self.output_schema = self.kiln_task.output_json_schema
         self.input_schema = self.kiln_task.input_json_schema
+        self.default_tags = tags

     async def invoke_returning_raw(
         self,
@@ -148,6 +152,7 @@ class BaseAdapter(metaclass=ABCMeta):
                 ),
             ),
             intermediate_outputs=run_output.intermediate_outputs,
+            tags=self.default_tags or [],
        )

        exclude_fields = {
@@ -39,8 +39,9 @@ class LangchainAdapter(BaseAdapter):
         model_name: str | None = None,
         provider: str | None = None,
         prompt_builder: BasePromptBuilder | None = None,
+        tags: list[str] | None = None,
     ):
-        super().__init__(kiln_task, prompt_builder=prompt_builder)
+        super().__init__(kiln_task, prompt_builder=prompt_builder, tags=tags)
         if custom_model is not None:
             self._model = custom_model

@@ -198,6 +199,9 @@ async def langchain_model_from_provider(
     if provider.name == ModelProviderName.openai:
         api_key = Config.shared().open_ai_api_key
         return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
+    elif provider.name == ModelProviderName.openai_compatible:
+        # See provider_tools.py for how base_url, key and other parameters are set
+        return ChatOpenAI(**provider.provider_options)  # type: ignore[arg-type]
     elif provider.name == ModelProviderName.groq:
         api_key = Config.shared().groq_api_key
         if api_key is None:
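For the new openai_compatible branch, the adapter reads no Kiln-level API key; ChatOpenAI is built directly from provider.provider_options, which openai_compatible_provider_model() (later in this diff) populates with the model id, the provider's API key, and its base URL. Roughly, with illustrative values:

    # Values are illustrative; the keys match those set in provider_tools.py below.
    provider_options = {
        "model": "my-model",
        "api_key": "sk-...",                           # may be None for keyless servers
        "openai_api_base": "http://localhost:8000/v1",
    }
    model = ChatOpenAI(**provider_options)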
@@ -22,6 +22,8 @@ class ModelProviderName(str, Enum):
     openrouter = "openrouter"
     fireworks_ai = "fireworks_ai"
     kiln_fine_tune = "kiln_fine_tune"
+    kiln_custom_registry = "kiln_custom_registry"
+    openai_compatible = "openai_compatible"


 class ModelFamily(str, Enum):
@@ -54,6 +56,7 @@ class ModelName(str, Enum):
     llama_3_2_3b = "llama_3_2_3b"
     llama_3_2_11b = "llama_3_2_11b"
     llama_3_2_90b = "llama_3_2_90b"
+    llama_3_3_70b = "llama_3_3_70b"
     gpt_4o_mini = "gpt_4o_mini"
     gpt_4o = "gpt_4o"
     phi_3_5 = "phi_3_5"
@@ -502,6 +505,46 @@ built_in_models: List[KilnModel] = [
             ),
         ],
     ),
+    # Llama 3.3 70B
+    KilnModel(
+        family=ModelFamily.llama,
+        name=ModelName.llama_3_3_70b,
+        friendly_name="Llama 3.3 70B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "meta-llama/llama-3.3-70b-instruct"},
+                # Openrouter not supporing tools yet. Once they do probably can remove. JSON mode sometimes works, but not consistently.
+                supports_structured_output=False,
+                supports_data_gen=False,
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.groq,
+                supports_structured_output=True,
+                supports_data_gen=True,
+                provider_options={"model": "llama-3.3-70b-versatile"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "llama3.3"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                # Finetuning not live yet
+                # provider_finetune_id="accounts/fireworks/models/llama-v3p3-70b-instruct",
+                supports_structured_output=True,
+                supports_data_gen=True,
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p3-70b-instruct"
+                },
+            ),
+        ],
+    ),
     # Phi 3.5
     KilnModel(
         family=ModelFamily.phi,
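The Llama 3.3 70B entry is exposed through four providers with different capability flags: the OpenRouter entry disables structured output and data generation (falling back to json_mode for langchain), Groq and Fireworks enable both, and the Ollama entry relies on the field defaults. A small sketch for inspecting those flags; the ModelName import from ml_model_list is assumed:

    # Sketch: list the providers of the new model and their structured-output flag.
    from kiln_ai.adapters.ml_model_list import ModelName, built_in_models

    llama_3_3 = next(m for m in built_in_models if m.name == ModelName.llama_3_3_70b)
    for provider in llama_3_3.providers:
        print(provider.name, provider.supports_structured_output)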
@@ -598,18 +641,6 @@ built_in_models: List[KilnModel] = [
         name=ModelName.mixtral_8x7b,
         friendly_name="Mixtral 8x7B",
         providers=[
-            KilnModelProvider(
-                name=ModelProviderName.fireworks_ai,
-                provider_options={
-                    "model": "accounts/fireworks/models/mixtral-8x7b-instruct-hf",
-                },
-                provider_finetune_id="accounts/fireworks/models/mixtral-8x7b-instruct-hf",
-                adapter_options={
-                    "langchain": {
-                        "with_structured_output_options": {"method": "json_mode"}
-                    }
-                },
-            ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 provider_options={"model": "mistralai/mixtral-8x7b-instruct"},
@@ -6,6 +6,7 @@ import requests
 from pydantic import BaseModel, Field

 from kiln_ai.adapters.ml_model_list import ModelProviderName, built_in_models
+from kiln_ai.utils.config import Config


 def ollama_base_url() -> str:
@@ -16,9 +17,9 @@ def ollama_base_url() -> str:
     The base URL to use for Ollama API calls, using environment variable if set
     or falling back to localhost default
     """
-    env_base_url = os.getenv("OLLAMA_BASE_URL")
-    if env_base_url is not None:
-        return env_base_url
+    config_base_url = Config.shared().ollama_base_url
+    if config_base_url:
+        return config_base_url
     return "http://localhost:11434"
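ollama_base_url() now resolves the base URL from Config.shared().ollama_base_url instead of the OLLAMA_BASE_URL environment variable, still falling back to http://localhost:11434 when nothing is configured. A sketch of the new lookup order, mocking Config the same way the tests in this release do; the kiln_ai.adapters.ollama_tools module path is an assumption:

    from unittest.mock import patch

    from kiln_ai.adapters.ollama_tools import ollama_base_url  # path assumed

    with patch("kiln_ai.adapters.ollama_tools.Config.shared") as mock_config:
        mock_config.return_value.ollama_base_url = "http://my-ollama:11434"
        assert ollama_base_url() == "http://my-ollama:11434"

        mock_config.return_value.ollama_base_url = None  # falsy -> localhost fallback
        assert ollama_base_url() == "http://localhost:11434"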
 
@@ -11,6 +11,7 @@ from kiln_ai.adapters.ml_model_list import (
 from kiln_ai.adapters.ollama_tools import (
     get_ollama_connection,
 )
+from kiln_ai.datamodel import Finetune, Task
 from kiln_ai.datamodel.registry import project_from_id

 from ..utils.config import Config
@@ -107,10 +108,18 @@ async def kiln_model_provider_from(
     if provider_name == ModelProviderName.kiln_fine_tune:
         return finetune_provider_model(name)

+    if provider_name == ModelProviderName.openai_compatible:
+        return openai_compatible_provider_model(name)
+
     built_in_model = await builtin_model_from(name, provider_name)
     if built_in_model:
         return built_in_model

+    # For custom registry, get the provider name and model name from the model id
+    if provider_name == ModelProviderName.kiln_custom_registry:
+        provider_name = name.split("::", 1)[0]
+        name = name.split("::", 1)[1]
+
     # Custom/untested model. Set untested, and build a ModelProvider at runtime
     if provider_name is None:
         raise ValueError("Provider name is required for custom models")
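Models registered through the custom registry are addressed by a combined id of the form "provider::model"; kiln_model_provider_from splits that id on the first "::" and then treats the result like any other custom, untested model further down the function. For example:

    # Sketch: how a custom-registry model id is taken apart (the id is illustrative).
    model_id = "openai::my-custom-gpt"
    provider_name, model_name = model_id.split("::", 1)
    # provider_name == "openai", model_name == "my-custom-gpt"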
@@ -130,6 +139,45 @@ async def kiln_model_provider_from(
 finetune_cache: dict[str, KilnModelProvider] = {}


+def openai_compatible_provider_model(
+    model_id: str,
+) -> KilnModelProvider:
+    try:
+        openai_provider_name, model_id = model_id.split("::")
+    except Exception:
+        raise ValueError(f"Invalid openai compatible model ID: {model_id}")
+
+    openai_compatible_providers = Config.shared().openai_compatible_providers or []
+    provider = next(
+        filter(
+            lambda p: p.get("name") == openai_provider_name, openai_compatible_providers
+        ),
+        None,
+    )
+    if provider is None:
+        raise ValueError(f"OpenAI compatible provider {openai_provider_name} not found")
+
+    # API key optional some providers don't use it
+    api_key = provider.get("api_key")
+    base_url = provider.get("base_url")
+    if base_url is None:
+        raise ValueError(
+            f"OpenAI compatible provider {openai_provider_name} has no base URL"
+        )
+
+    return KilnModelProvider(
+        name=ModelProviderName.openai_compatible,
+        provider_options={
+            "model": model_id,
+            "api_key": api_key,
+            "openai_api_base": base_url,
+        },
+        supports_structured_output=False,
+        supports_data_gen=False,
+        untested_model=True,
+    )
+
+
 def finetune_provider_model(
     model_id: str,
 ) -> KilnModelProvider:
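openai_compatible_provider_model() expects ids in the same "provider name::model id" shape and looks the provider up in Config.shared().openai_compatible_providers, a list of dicts carrying at least a name and base_url plus an optional api_key. How that list is persisted is not shown in this diff; a sketch of the expected shape:

    # The keys ("name", "base_url", "api_key") are taken from the function above;
    # the provider entry itself is illustrative.
    openai_compatible_providers = [
        {
            "name": "local_vllm",
            "base_url": "http://localhost:8000/v1",
            "api_key": None,  # optional; some servers don't require one
        },
    ]

    # A model on that provider is then referenced as:
    model_id = "local_vllm::qwen2.5-7b-instruct"
    # openai_compatible_provider_model(model_id) returns a KilnModelProvider whose
    # provider_options carry the model id, api_key and openai_api_base, marked
    # untested with structured output and data gen disabled.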
@@ -143,10 +191,10 @@ def finetune_provider_model(
     project = project_from_id(project_id)
     if project is None:
         raise ValueError(f"Project {project_id} not found")
-    task = next((t for t in project.tasks() if t.id == task_id), None)
+    task = Task.from_id_and_parent_path(task_id, project.path)
     if task is None:
         raise ValueError(f"Task {task_id} not found")
-    fine_tune = next((f for f in task.finetunes() if f.id == fine_tune_id), None)
+    fine_tune = Finetune.from_id_and_parent_path(fine_tune_id, task.path)
     if fine_tune is None:
         raise ValueError(f"Fine tune {fine_tune_id} not found")
     if fine_tune.fine_tune_model_id is None:
@@ -220,6 +268,10 @@ def provider_name_from_id(id: str) -> str:
             return "Fine Tuned Models"
         case ModelProviderName.fireworks_ai:
             return "Fireworks AI"
+        case ModelProviderName.kiln_custom_registry:
+            return "Custom Models"
+        case ModelProviderName.openai_compatible:
+            return "OpenAI Compatible"
         case _:
             # triggers pyright warning if I miss a case
             raise_exhaustive_error(enum_id)
@@ -233,6 +285,7 @@ def provider_options_for_custom_model(
     """
     Generated model provider options for a custom model. Each has their own format/options.
     """
+
     if provider_name not in ModelProviderName.__members__:
         raise ValueError(f"Invalid provider name: {provider_name}")

@@ -249,10 +302,18 @@ def provider_options_for_custom_model(
             | ModelProviderName.groq
         ):
             return {"model": model_name}
+        case ModelProviderName.kiln_custom_registry:
+            raise ValueError(
+                "Custom models from registry should be parsed into provider/model before calling this."
+            )
         case ModelProviderName.kiln_fine_tune:
             raise ValueError(
                 "Fine tuned models should populate provider options via another path"
             )
+        case ModelProviderName.openai_compatible:
+            raise ValueError(
+                "OpenAI compatible models should populate provider options via another path"
+            )
        case _:
            # triggers pyright warning if I miss a case
            raise_exhaustive_error(enum_id)
@@ -43,8 +43,10 @@ feedback describing what should be improved. Your job is to understand the evalu
     @classmethod
     def _original_prompt(cls, run: TaskRun, task: Task) -> str:
         prompt_builder_class: Type[BasePromptBuilder] | None = None
-        prompt_builder_name = run.output.source.properties.get(
-            "prompt_builder_name", None
+        prompt_builder_name = (
+            run.output.source.properties.get("prompt_builder_name", None)
+            if run.output.source
+            else None
         )
         if prompt_builder_name is not None and isinstance(prompt_builder_name, str):
             prompt_builder_class = prompt_builder_registry.get(
@@ -1,12 +1,20 @@
+import os
 from unittest.mock import AsyncMock, MagicMock, patch

+import pytest
+from langchain_aws import ChatBedrockConverse
 from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+from langchain_fireworks import ChatFireworks
 from langchain_groq import ChatGroq
+from langchain_ollama import ChatOllama
+from langchain_openai import ChatOpenAI

 from kiln_ai.adapters.langchain_adapters import (
     LangchainAdapter,
     get_structured_output_options,
+    langchain_model_from_provider,
 )
+from kiln_ai.adapters.ml_model_list import KilnModelProvider, ModelProviderName
 from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder
 from kiln_ai.adapters.test_prompt_adaptors import build_test_task

@@ -150,3 +158,178 @@ async def test_get_structured_output_options():
     ):
         options = await get_structured_output_options("model_name", "provider")
         assert options == {}
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_openai():
+    provider = KilnModelProvider(
+        name=ModelProviderName.openai, provider_options={"model": "gpt-4"}
+    )
+
+    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
+        mock_config.return_value.open_ai_api_key = "test_key"
+        model = await langchain_model_from_provider(provider, "gpt-4")
+        assert isinstance(model, ChatOpenAI)
+        assert model.model_name == "gpt-4"
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_groq():
+    provider = KilnModelProvider(
+        name=ModelProviderName.groq, provider_options={"model": "mixtral-8x7b"}
+    )
+
+    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
+        mock_config.return_value.groq_api_key = "test_key"
+        model = await langchain_model_from_provider(provider, "mixtral-8x7b")
+        assert isinstance(model, ChatGroq)
+        assert model.model_name == "mixtral-8x7b"
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_bedrock():
+    provider = KilnModelProvider(
+        name=ModelProviderName.amazon_bedrock,
+        provider_options={"model": "anthropic.claude-v2", "region_name": "us-east-1"},
+    )
+
+    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
+        mock_config.return_value.bedrock_access_key = "test_access"
+        mock_config.return_value.bedrock_secret_key = "test_secret"
+        model = await langchain_model_from_provider(provider, "anthropic.claude-v2")
+        assert isinstance(model, ChatBedrockConverse)
+        assert os.environ.get("AWS_ACCESS_KEY_ID") == "test_access"
+        assert os.environ.get("AWS_SECRET_ACCESS_KEY") == "test_secret"
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_fireworks():
+    provider = KilnModelProvider(
+        name=ModelProviderName.fireworks_ai, provider_options={"model": "mixtral-8x7b"}
+    )
+
+    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
+        mock_config.return_value.fireworks_api_key = "test_key"
+        model = await langchain_model_from_provider(provider, "mixtral-8x7b")
+        assert isinstance(model, ChatFireworks)
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_ollama():
+    provider = KilnModelProvider(
+        name=ModelProviderName.ollama,
+        provider_options={"model": "llama2", "model_aliases": ["llama2-uncensored"]},
+    )
+
+    mock_connection = MagicMock()
+    with (
+        patch(
+            "kiln_ai.adapters.langchain_adapters.get_ollama_connection",
+            return_value=AsyncMock(return_value=mock_connection),
+        ),
+        patch(
+            "kiln_ai.adapters.langchain_adapters.ollama_model_installed",
+            return_value=True,
+        ),
+        patch(
+            "kiln_ai.adapters.langchain_adapters.ollama_base_url",
+            return_value="http://localhost:11434",
+        ),
+    ):
+        model = await langchain_model_from_provider(provider, "llama2")
+        assert isinstance(model, ChatOllama)
+        assert model.model == "llama2"
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_invalid():
+    provider = KilnModelProvider.model_construct(
+        name="invalid_provider", provider_options={}
+    )
+
+    with pytest.raises(ValueError, match="Invalid model or provider"):
+        await langchain_model_from_provider(provider, "test_model")
+
+
+@pytest.mark.asyncio
+async def test_langchain_adapter_model_caching(tmp_path):
+    task = build_test_task(tmp_path)
+    custom_model = ChatGroq(model="mixtral-8x7b", groq_api_key="test")
+
+    adapter = LangchainAdapter(kiln_task=task, custom_model=custom_model)
+
+    # First call should return the cached model
+    model1 = await adapter.model()
+    assert model1 is custom_model
+
+    # Second call should return the same cached instance
+    model2 = await adapter.model()
+    assert model2 is model1
+
+
+@pytest.mark.asyncio
+async def test_langchain_adapter_model_structured_output(tmp_path):
+    task = build_test_task(tmp_path)
+    task.output_json_schema = """
+    {
+        "type": "object",
+        "properties": {
+            "count": {"type": "integer"}
+        }
+    }
+    """
+
+    mock_model = MagicMock()
+    mock_model.with_structured_output = MagicMock(return_value="structured_model")
+
+    adapter = LangchainAdapter(
+        kiln_task=task, model_name="test_model", provider="test_provider"
+    )
+
+    with (
+        patch(
+            "kiln_ai.adapters.langchain_adapters.langchain_model_from",
+            AsyncMock(return_value=mock_model),
+        ),
+        patch(
+            "kiln_ai.adapters.langchain_adapters.get_structured_output_options",
+            AsyncMock(return_value={"option1": "value1"}),
+        ),
+    ):
+        model = await adapter.model()
+
+        # Verify the model was configured with structured output
+        mock_model.with_structured_output.assert_called_once_with(
+            {
+                "type": "object",
+                "properties": {"count": {"type": "integer"}},
+                "title": "task_response",
+                "description": "A response from the task",
+            },
+            include_raw=True,
+            option1="value1",
+        )
+        assert model == "structured_model"
+
+
+@pytest.mark.asyncio
+async def test_langchain_adapter_model_no_structured_output_support(tmp_path):
+    task = build_test_task(tmp_path)
+    task.output_json_schema = (
+        '{"type": "object", "properties": {"count": {"type": "integer"}}}'
+    )
+
+    mock_model = MagicMock()
+    # Remove with_structured_output method
+    del mock_model.with_structured_output
+
+    adapter = LangchainAdapter(
+        kiln_task=task, model_name="test_model", provider="test_provider"
+    )
+
+    with patch(
+        "kiln_ai.adapters.langchain_adapters.langchain_model_from",
+        AsyncMock(return_value=mock_model),
+    ):
+        with pytest.raises(ValueError, match="does not support structured output"):
+            await adapter.model()