kiln-ai 0.7.0__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.

Potentially problematic release: this version of kiln-ai might be problematic.

@@ -22,6 +22,7 @@ class ModelProviderName(str, Enum):
     openrouter = "openrouter"
     fireworks_ai = "fireworks_ai"
     kiln_fine_tune = "kiln_fine_tune"
+    kiln_custom_registry = "kiln_custom_registry"
 
 
 class ModelFamily(str, Enum):
@@ -54,6 +55,7 @@ class ModelName(str, Enum):
     llama_3_2_3b = "llama_3_2_3b"
     llama_3_2_11b = "llama_3_2_11b"
     llama_3_2_90b = "llama_3_2_90b"
+    llama_3_3_70b = "llama_3_3_70b"
     gpt_4o_mini = "gpt_4o_mini"
     gpt_4o = "gpt_4o"
     phi_3_5 = "phi_3_5"
@@ -502,6 +504,38 @@ built_in_models: List[KilnModel] = [
             ),
         ],
     ),
+    # Llama 3.3 70B
+    KilnModel(
+        family=ModelFamily.llama,
+        name=ModelName.llama_3_3_70b,
+        friendly_name="Llama 3.3 70B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "meta-llama/llama-3.3-70b-instruct"},
+                # Openrouter not supporing tools yet. Once they do probably can remove. JSON mode sometimes works, but not consistently.
+                supports_structured_output=False,
+                supports_data_gen=False,
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "llama3.3"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                # Finetuning not live yet
+                # provider_finetune_id="accounts/fireworks/models/llama-v3p3-70b-instruct",
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p3-70b-instruct"
+                },
+            ),
+        ],
+    ),
     # Phi 3.5
     KilnModel(
         family=ModelFamily.phi,
@@ -598,18 +632,6 @@ built_in_models: List[KilnModel] = [
         name=ModelName.mixtral_8x7b,
         friendly_name="Mixtral 8x7B",
         providers=[
-            KilnModelProvider(
-                name=ModelProviderName.fireworks_ai,
-                provider_options={
-                    "model": "accounts/fireworks/models/mixtral-8x7b-instruct-hf",
-                },
-                provider_finetune_id="accounts/fireworks/models/mixtral-8x7b-instruct-hf",
-                adapter_options={
-                    "langchain": {
-                        "with_structured_output_options": {"method": "json_mode"}
-                    }
-                },
-            ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 provider_options={"model": "mistralai/mixtral-8x7b-instruct"},
@@ -6,6 +6,7 @@ import requests
 from pydantic import BaseModel, Field
 
 from kiln_ai.adapters.ml_model_list import ModelProviderName, built_in_models
+from kiln_ai.utils.config import Config
 
 
 def ollama_base_url() -> str:
@@ -16,9 +17,9 @@ def ollama_base_url() -> str:
     The base URL to use for Ollama API calls, using environment variable if set
     or falling back to localhost default
     """
-    env_base_url = os.getenv("OLLAMA_BASE_URL")
-    if env_base_url is not None:
-        return env_base_url
+    config_base_url = Config.shared().ollama_base_url
+    if config_base_url:
+        return config_base_url
     return "http://localhost:11434"
 
 
@@ -11,6 +11,7 @@ from kiln_ai.adapters.ml_model_list import (
 from kiln_ai.adapters.ollama_tools import (
     get_ollama_connection,
 )
+from kiln_ai.datamodel import Finetune, Task
 from kiln_ai.datamodel.registry import project_from_id
 
 from ..utils.config import Config
@@ -111,6 +112,11 @@ async def kiln_model_provider_from(
     if built_in_model:
         return built_in_model
 
+    # For custom registry, get the provider name and model name from the model id
+    if provider_name == ModelProviderName.kiln_custom_registry:
+        provider_name = name.split("::", 1)[0]
+        name = name.split("::", 1)[1]
+
     # Custom/untested model. Set untested, and build a ModelProvider at runtime
     if provider_name is None:
         raise ValueError("Provider name is required for custom models")
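The new kiln_custom_registry branch expects model IDs of the form "provider_name::model_name" and splits them before falling through to the untested-model path. A minimal sketch of that ID format, using the same example ID as the tests added later in this diff:

# Illustrative only: the "provider::model" ID format handled by the new branch.
model_id = "openai::gpt-4-turbo"
provider_name, model_name = model_id.split("::", 1)
assert provider_name == "openai"
assert model_name == "gpt-4-turbo"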
@@ -143,10 +149,10 @@ def finetune_provider_model(
     project = project_from_id(project_id)
     if project is None:
         raise ValueError(f"Project {project_id} not found")
-    task = next((t for t in project.tasks() if t.id == task_id), None)
+    task = Task.from_id_and_parent_path(task_id, project.path)
     if task is None:
         raise ValueError(f"Task {task_id} not found")
-    fine_tune = next((f for f in task.finetunes() if f.id == fine_tune_id), None)
+    fine_tune = Finetune.from_id_and_parent_path(fine_tune_id, task.path)
     if fine_tune is None:
         raise ValueError(f"Fine tune {fine_tune_id} not found")
     if fine_tune.fine_tune_model_id is None:
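This change swaps the linear scans of project.tasks() and task.finetunes() for direct path-based lookups via Task.from_id_and_parent_path and Finetune.from_id_and_parent_path. For context, the IDs consumed here have three parts, as exercised by the tests added later in this diff; a minimal sketch of that format (the exact parsing inside finetune_provider_model is not shown in this hunk, so treat this as an assumption):

# Illustrative only: the three-part fine-tune model ID used in the tests below.
model_id = "project-123::task-456::finetune-789"
project_id, task_id, fine_tune_id = model_id.split("::")
assert (project_id, task_id, fine_tune_id) == ("project-123", "task-456", "finetune-789")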
@@ -220,6 +226,8 @@ def provider_name_from_id(id: str) -> str:
             return "Fine Tuned Models"
         case ModelProviderName.fireworks_ai:
             return "Fireworks AI"
+        case ModelProviderName.kiln_custom_registry:
+            return "Custom Models"
         case _:
             # triggers pyright warning if I miss a case
             raise_exhaustive_error(enum_id)
@@ -233,6 +241,7 @@ def provider_options_for_custom_model(
     """
     Generated model provider options for a custom model. Each has their own format/options.
     """
+
     if provider_name not in ModelProviderName.__members__:
         raise ValueError(f"Invalid provider name: {provider_name}")
 
@@ -249,6 +258,10 @@
             | ModelProviderName.groq
         ):
             return {"model": model_name}
+        case ModelProviderName.kiln_custom_registry:
+            raise ValueError(
+                "Custom models from registry should be parsed into provider/model before calling this."
+            )
         case ModelProviderName.kiln_fine_tune:
             raise ValueError(
                 "Fine tuned models should populate provider options via another path"
@@ -43,8 +43,10 @@ feedback describing what should be improved. Your job is to understand the evalu
     @classmethod
     def _original_prompt(cls, run: TaskRun, task: Task) -> str:
         prompt_builder_class: Type[BasePromptBuilder] | None = None
-        prompt_builder_name = run.output.source.properties.get(
-            "prompt_builder_name", None
+        prompt_builder_name = (
+            run.output.source.properties.get("prompt_builder_name", None)
+            if run.output.source
+            else None
         )
         if prompt_builder_name is not None and isinstance(prompt_builder_name, str):
             prompt_builder_class = prompt_builder_registry.get(
@@ -1,12 +1,20 @@
+import os
 from unittest.mock import AsyncMock, MagicMock, patch
 
+import pytest
+from langchain_aws import ChatBedrockConverse
 from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+from langchain_fireworks import ChatFireworks
 from langchain_groq import ChatGroq
+from langchain_ollama import ChatOllama
+from langchain_openai import ChatOpenAI
 
 from kiln_ai.adapters.langchain_adapters import (
     LangchainAdapter,
     get_structured_output_options,
+    langchain_model_from_provider,
 )
+from kiln_ai.adapters.ml_model_list import KilnModelProvider, ModelProviderName
 from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder
 from kiln_ai.adapters.test_prompt_adaptors import build_test_task
 
@@ -150,3 +158,178 @@ async def test_get_structured_output_options():
     ):
         options = await get_structured_output_options("model_name", "provider")
         assert options == {}
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_openai():
+    provider = KilnModelProvider(
+        name=ModelProviderName.openai, provider_options={"model": "gpt-4"}
+    )
+
+    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
+        mock_config.return_value.open_ai_api_key = "test_key"
+        model = await langchain_model_from_provider(provider, "gpt-4")
+        assert isinstance(model, ChatOpenAI)
+        assert model.model_name == "gpt-4"
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_groq():
+    provider = KilnModelProvider(
+        name=ModelProviderName.groq, provider_options={"model": "mixtral-8x7b"}
+    )
+
+    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
+        mock_config.return_value.groq_api_key = "test_key"
+        model = await langchain_model_from_provider(provider, "mixtral-8x7b")
+        assert isinstance(model, ChatGroq)
+        assert model.model_name == "mixtral-8x7b"
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_bedrock():
+    provider = KilnModelProvider(
+        name=ModelProviderName.amazon_bedrock,
+        provider_options={"model": "anthropic.claude-v2", "region_name": "us-east-1"},
+    )
+
+    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
+        mock_config.return_value.bedrock_access_key = "test_access"
+        mock_config.return_value.bedrock_secret_key = "test_secret"
+        model = await langchain_model_from_provider(provider, "anthropic.claude-v2")
+        assert isinstance(model, ChatBedrockConverse)
+        assert os.environ.get("AWS_ACCESS_KEY_ID") == "test_access"
+        assert os.environ.get("AWS_SECRET_ACCESS_KEY") == "test_secret"
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_fireworks():
+    provider = KilnModelProvider(
+        name=ModelProviderName.fireworks_ai, provider_options={"model": "mixtral-8x7b"}
+    )
+
+    with patch("kiln_ai.adapters.langchain_adapters.Config.shared") as mock_config:
+        mock_config.return_value.fireworks_api_key = "test_key"
+        model = await langchain_model_from_provider(provider, "mixtral-8x7b")
+        assert isinstance(model, ChatFireworks)
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_ollama():
+    provider = KilnModelProvider(
+        name=ModelProviderName.ollama,
+        provider_options={"model": "llama2", "model_aliases": ["llama2-uncensored"]},
+    )
+
+    mock_connection = MagicMock()
+    with (
+        patch(
+            "kiln_ai.adapters.langchain_adapters.get_ollama_connection",
+            return_value=AsyncMock(return_value=mock_connection),
+        ),
+        patch(
+            "kiln_ai.adapters.langchain_adapters.ollama_model_installed",
+            return_value=True,
+        ),
+        patch(
+            "kiln_ai.adapters.langchain_adapters.ollama_base_url",
+            return_value="http://localhost:11434",
+        ),
+    ):
+        model = await langchain_model_from_provider(provider, "llama2")
+        assert isinstance(model, ChatOllama)
+        assert model.model == "llama2"
+
+
+@pytest.mark.asyncio
+async def test_langchain_model_from_provider_invalid():
+    provider = KilnModelProvider.model_construct(
+        name="invalid_provider", provider_options={}
+    )
+
+    with pytest.raises(ValueError, match="Invalid model or provider"):
+        await langchain_model_from_provider(provider, "test_model")
+
+
+@pytest.mark.asyncio
+async def test_langchain_adapter_model_caching(tmp_path):
+    task = build_test_task(tmp_path)
+    custom_model = ChatGroq(model="mixtral-8x7b", groq_api_key="test")
+
+    adapter = LangchainAdapter(kiln_task=task, custom_model=custom_model)
+
+    # First call should return the cached model
+    model1 = await adapter.model()
+    assert model1 is custom_model
+
+    # Second call should return the same cached instance
+    model2 = await adapter.model()
+    assert model2 is model1
+
+
+@pytest.mark.asyncio
+async def test_langchain_adapter_model_structured_output(tmp_path):
+    task = build_test_task(tmp_path)
+    task.output_json_schema = """
+        {
+            "type": "object",
+            "properties": {
+                "count": {"type": "integer"}
+            }
+        }
+        """
+
+    mock_model = MagicMock()
+    mock_model.with_structured_output = MagicMock(return_value="structured_model")
+
+    adapter = LangchainAdapter(
+        kiln_task=task, model_name="test_model", provider="test_provider"
+    )
+
+    with (
+        patch(
+            "kiln_ai.adapters.langchain_adapters.langchain_model_from",
+            AsyncMock(return_value=mock_model),
+        ),
+        patch(
+            "kiln_ai.adapters.langchain_adapters.get_structured_output_options",
+            AsyncMock(return_value={"option1": "value1"}),
+        ),
+    ):
+        model = await adapter.model()
+
+        # Verify the model was configured with structured output
+        mock_model.with_structured_output.assert_called_once_with(
+            {
+                "type": "object",
+                "properties": {"count": {"type": "integer"}},
+                "title": "task_response",
+                "description": "A response from the task",
+            },
+            include_raw=True,
+            option1="value1",
+        )
+        assert model == "structured_model"
+
+
+@pytest.mark.asyncio
+async def test_langchain_adapter_model_no_structured_output_support(tmp_path):
+    task = build_test_task(tmp_path)
+    task.output_json_schema = (
+        '{"type": "object", "properties": {"count": {"type": "integer"}}}'
+    )
+
+    mock_model = MagicMock()
+    # Remove with_structured_output method
+    del mock_model.with_structured_output
+
+    adapter = LangchainAdapter(
+        kiln_task=task, model_name="test_model", provider="test_provider"
+    )
+
+    with patch(
+        "kiln_ai.adapters.langchain_adapters.langchain_model_from",
+        AsyncMock(return_value=mock_model),
+    ):
+        with pytest.raises(ValueError, match="does not support structured output"):
+            await adapter.model()
@@ -1,14 +1,18 @@
-from unittest.mock import AsyncMock, patch
+from unittest.mock import AsyncMock, Mock, patch
 
 import pytest
 
 from kiln_ai.adapters.ml_model_list import (
+    KilnModel,
     ModelName,
     ModelProviderName,
 )
 from kiln_ai.adapters.ollama_tools import OllamaConnection
 from kiln_ai.adapters.provider_tools import (
+    builtin_model_from,
     check_provider_warnings,
+    finetune_cache,
+    finetune_provider_model,
     get_model_and_provider,
     kiln_model_provider_from,
     provider_enabled,
@@ -16,6 +20,14 @@ from kiln_ai.adapters.provider_tools import (
     provider_options_for_custom_model,
     provider_warnings,
 )
+from kiln_ai.datamodel import Finetune, Task
+
+
+@pytest.fixture(autouse=True)
+def clear_finetune_cache():
+    """Clear the finetune provider model cache before each test"""
+    finetune_cache.clear()
+    yield
 
 
 @pytest.fixture
@@ -24,6 +36,34 @@ def mock_config():
         yield mock
 
 
+@pytest.fixture
+def mock_project():
+    with patch("kiln_ai.adapters.provider_tools.project_from_id") as mock:
+        project = Mock()
+        project.path = "/fake/path"
+        mock.return_value = project
+        yield mock
+
+
+@pytest.fixture
+def mock_task():
+    with patch("kiln_ai.datamodel.Task.from_id_and_parent_path") as mock:
+        task = Mock(spec=Task)
+        task.path = "/fake/path/task"
+        mock.return_value = task
+        yield mock
+
+
+@pytest.fixture
+def mock_finetune():
+    with patch("kiln_ai.datamodel.Finetune.from_id_and_parent_path") as mock:
+        finetune = Mock(spec=Finetune)
+        finetune.provider = ModelProviderName.openai
+        finetune.fine_tune_model_id = "ft:gpt-3.5-turbo:custom:model-123"
+        mock.return_value = finetune
+        yield mock
+
+
 def test_check_provider_warnings_no_warning(mock_config):
     mock_config.return_value = "some_value"
 
@@ -103,6 +143,8 @@ def test_provider_name_from_id_case_sensitivity():
         (ModelProviderName.ollama, "Ollama"),
         (ModelProviderName.openai, "OpenAI"),
         (ModelProviderName.fireworks_ai, "Fireworks AI"),
+        (ModelProviderName.kiln_fine_tune, "Fine Tuned Models"),
+        (ModelProviderName.kiln_custom_registry, "Custom Models"),
     ],
 )
 def test_provider_name_from_id_parametrized(provider_id, expected_name):
@@ -310,3 +352,180 @@ def test_provider_options_for_custom_model_invalid_enum():
     """Test handling of invalid enum value"""
     with pytest.raises(ValueError):
         provider_options_for_custom_model("model_name", "invalid_enum_value")
+
+
+@pytest.mark.asyncio
+async def test_kiln_model_provider_from_custom_registry(mock_config):
+    # Mock config to pass provider warnings check
+    mock_config.return_value = "fake-api-key"
+
+    # Test with a custom registry model ID in format "provider::model_name"
+    provider = await kiln_model_provider_from(
+        "openai::gpt-4-turbo", ModelProviderName.kiln_custom_registry
+    )
+
+    assert provider.name == ModelProviderName.openai
+    assert provider.supports_structured_output is False
+    assert provider.supports_data_gen is False
+    assert provider.untested_model is True
+    assert provider.provider_options == {"model": "gpt-4-turbo"}
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_invalid_model():
+    """Test that an invalid model name returns None"""
+    result = await builtin_model_from("non_existent_model")
+    assert result is None
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_valid_model_default_provider(mock_config):
+    """Test getting a valid model with default provider"""
+    mock_config.return_value = "fake-api-key"
+
+    provider = await builtin_model_from(ModelName.phi_3_5)
+
+    assert provider is not None
+    assert provider.name == ModelProviderName.ollama
+    assert provider.provider_options["model"] == "phi3.5"
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_valid_model_specific_provider(mock_config):
+    """Test getting a valid model with specific provider"""
+    mock_config.return_value = "fake-api-key"
+
+    provider = await builtin_model_from(
+        ModelName.llama_3_1_70b, provider_name=ModelProviderName.groq
+    )
+
+    assert provider is not None
+    assert provider.name == ModelProviderName.groq
+    assert provider.provider_options["model"] == "llama-3.1-70b-versatile"
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_invalid_provider(mock_config):
+    """Test that requesting an invalid provider returns None"""
+    mock_config.return_value = "fake-api-key"
+
+    provider = await builtin_model_from(
+        ModelName.phi_3_5, provider_name="invalid_provider"
+    )
+
+    assert provider is None
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_model_no_providers():
+    """Test handling of a model with no providers"""
+    with patch("kiln_ai.adapters.provider_tools.built_in_models") as mock_models:
+        # Create a mock model with no providers
+        mock_model = KilnModel(
+            name=ModelName.phi_3_5,
+            friendly_name="Test Model",
+            providers=[],
+            family="test_family",
+        )
+        mock_models.__iter__.return_value = [mock_model]
+
+        with pytest.raises(ValueError) as exc_info:
+            await builtin_model_from(ModelName.phi_3_5)
+
+        assert str(exc_info.value) == f"Model {ModelName.phi_3_5} has no providers"
+
+
+@pytest.mark.asyncio
+async def test_builtin_model_from_provider_warning_check(mock_config):
+    """Test that provider warnings are checked"""
+    # Make the config check fail
+    mock_config.return_value = None
+
+    with pytest.raises(ValueError) as exc_info:
+        await builtin_model_from(ModelName.llama_3_1_70b, ModelProviderName.groq)
+
+    assert provider_warnings[ModelProviderName.groq].message in str(exc_info.value)
+
+
+def test_finetune_provider_model_success(mock_project, mock_task, mock_finetune):
+    """Test successful creation of a fine-tuned model provider"""
+    model_id = "project-123::task-456::finetune-789"
+
+    provider = finetune_provider_model(model_id)
+
+    assert provider.name == ModelProviderName.openai
+    assert provider.provider_options == {"model": "ft:gpt-3.5-turbo:custom:model-123"}
+
+    # Test cache
+    cached_provider = finetune_provider_model(model_id)
+    assert cached_provider is provider
+
+
+def test_finetune_provider_model_invalid_id():
+    """Test handling of invalid model ID format"""
+    with pytest.raises(ValueError) as exc_info:
+        finetune_provider_model("invalid-id-format")
+    assert str(exc_info.value) == "Invalid fine tune ID: invalid-id-format"
+
+
+def test_finetune_provider_model_project_not_found(mock_project):
+    """Test handling of non-existent project"""
+    mock_project.return_value = None
+
+    with pytest.raises(ValueError) as exc_info:
+        finetune_provider_model("project-123::task-456::finetune-789")
+    assert str(exc_info.value) == "Project project-123 not found"
+
+
+def test_finetune_provider_model_task_not_found(mock_project, mock_task):
+    """Test handling of non-existent task"""
+    mock_task.return_value = None
+
+    with pytest.raises(ValueError) as exc_info:
+        finetune_provider_model("project-123::task-456::finetune-789")
+    assert str(exc_info.value) == "Task task-456 not found"
+
+
+def test_finetune_provider_model_finetune_not_found(
+    mock_project, mock_task, mock_finetune
+):
+    """Test handling of non-existent fine-tune"""
+    mock_finetune.return_value = None
+
+    with pytest.raises(ValueError) as exc_info:
+        finetune_provider_model("project-123::task-456::finetune-789")
+    assert str(exc_info.value) == "Fine tune finetune-789 not found"
+
+
+def test_finetune_provider_model_incomplete_finetune(
+    mock_project, mock_task, mock_finetune
+):
+    """Test handling of incomplete fine-tune"""
+    finetune = Mock(spec=Finetune)
+    finetune.fine_tune_model_id = None
+    mock_finetune.return_value = finetune
+
+    with pytest.raises(ValueError) as exc_info:
+        finetune_provider_model("project-123::task-456::finetune-789")
+    assert (
+        str(exc_info.value)
+        == "Fine tune finetune-789 not completed. Refresh it's status in the fine-tune tab."
+    )
+
+
+def test_finetune_provider_model_fireworks_provider(
+    mock_project, mock_task, mock_finetune
+):
+    """Test creation of Fireworks AI provider with specific adapter options"""
+    finetune = Mock(spec=Finetune)
+    finetune.provider = ModelProviderName.fireworks_ai
+    finetune.fine_tune_model_id = "fireworks-model-123"
+    mock_finetune.return_value = finetune
+
+    provider = finetune_provider_model("project-123::task-456::finetune-789")
+
+    assert provider.name == ModelProviderName.fireworks_ai
+    assert provider.provider_options == {"model": "fireworks-model-123"}
+    assert provider.adapter_options == {
+        "langchain": {"with_structured_output_options": {"method": "json_mode"}}
+    }