kiln-ai 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of kiln-ai might be problematic.

Files changed (42)
  1. kiln_ai/adapters/__init__.py +11 -1
  2. kiln_ai/adapters/adapter_registry.py +19 -0
  3. kiln_ai/adapters/data_gen/__init__.py +11 -0
  4. kiln_ai/adapters/data_gen/data_gen_task.py +69 -1
  5. kiln_ai/adapters/data_gen/test_data_gen_task.py +30 -21
  6. kiln_ai/adapters/fine_tune/__init__.py +14 -0
  7. kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
  8. kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
  9. kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
  10. kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
  11. kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
  12. kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
  13. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
  14. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +455 -0
  15. kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
  16. kiln_ai/adapters/langchain_adapters.py +103 -13
  17. kiln_ai/adapters/ml_model_list.py +218 -304
  18. kiln_ai/adapters/ollama_tools.py +114 -0
  19. kiln_ai/adapters/provider_tools.py +295 -0
  20. kiln_ai/adapters/repair/test_repair_task.py +6 -11
  21. kiln_ai/adapters/test_langchain_adapter.py +46 -18
  22. kiln_ai/adapters/test_ollama_tools.py +42 -0
  23. kiln_ai/adapters/test_prompt_adaptors.py +7 -5
  24. kiln_ai/adapters/test_provider_tools.py +312 -0
  25. kiln_ai/adapters/test_structured_output.py +22 -43
  26. kiln_ai/datamodel/__init__.py +235 -22
  27. kiln_ai/datamodel/basemodel.py +30 -0
  28. kiln_ai/datamodel/registry.py +31 -0
  29. kiln_ai/datamodel/test_basemodel.py +29 -1
  30. kiln_ai/datamodel/test_dataset_split.py +234 -0
  31. kiln_ai/datamodel/test_example_models.py +12 -0
  32. kiln_ai/datamodel/test_models.py +91 -1
  33. kiln_ai/datamodel/test_registry.py +96 -0
  34. kiln_ai/utils/config.py +9 -0
  35. kiln_ai/utils/name_generator.py +125 -0
  36. kiln_ai/utils/test_name_geneator.py +47 -0
  37. {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/METADATA +4 -2
  38. kiln_ai-0.7.0.dist-info/RECORD +56 -0
  39. kiln_ai/adapters/test_ml_model_list.py +0 -181
  40. kiln_ai-0.6.0.dist-info/RECORD +0 -36
  41. {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/WHEEL +0 -0
  42. {kiln_ai-0.6.0.dist-info → kiln_ai-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/ml_model_list.py

@@ -1,20 +1,8 @@
-import os
-from dataclasses import dataclass
 from enum import Enum
-from os import getenv
-from typing import Any, Dict, List, NoReturn
+from typing import Dict, List
 
-import httpx
-import requests
-from langchain_aws import ChatBedrockConverse
-from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_groq import ChatGroq
-from langchain_ollama import ChatOllama
-from langchain_openai import ChatOpenAI
 from pydantic import BaseModel
 
-from ..utils.config import Config
-
 """
 Provides model configuration and management for various LLM providers and models.
 This module handles the integration with different AI model providers and their respective models,
@@ -32,6 +20,8 @@ class ModelProviderName(str, Enum):
     amazon_bedrock = "amazon_bedrock"
     ollama = "ollama"
     openrouter = "openrouter"
+    fireworks_ai = "fireworks_ai"
+    kiln_fine_tune = "kiln_fine_tune"
 
 
 class ModelFamily(str, Enum):
@@ -46,6 +36,8 @@ class ModelFamily(str, Enum):
     gemma = "gemma"
     gemini = "gemini"
     claude = "claude"
+    mixtral = "mixtral"
+    qwen = "qwen"
 
 
 # Where models have instruct and raw versions, instruct is default and raw is specified
@@ -58,6 +50,7 @@ class ModelName(str, Enum):
     llama_3_1_8b = "llama_3_1_8b"
     llama_3_1_70b = "llama_3_1_70b"
     llama_3_1_405b = "llama_3_1_405b"
+    llama_3_2_1b = "llama_3_2_1b"
     llama_3_2_3b = "llama_3_2_3b"
     llama_3_2_11b = "llama_3_2_11b"
     llama_3_2_90b = "llama_3_2_90b"
@@ -75,6 +68,9 @@ class ModelName(str, Enum):
     gemini_1_5_flash_8b = "gemini_1_5_flash_8b"
     gemini_1_5_pro = "gemini_1_5_pro"
     nemotron_70b = "nemotron_70b"
+    mixtral_8x7b = "mixtral_8x7b"
+    qwen_2p5_7b = "qwen_2p5_7b"
+    qwen_2p5_72b = "qwen_2p5_72b"
 
 
 class KilnModelProvider(BaseModel):
@@ -84,13 +80,20 @@ class KilnModelProvider(BaseModel):
     Attributes:
         name: The provider's identifier
         supports_structured_output: Whether the provider supports structured output formats
+        supports_data_gen: Whether the provider supports data generation
+        untested_model: Whether the model is untested (typically user added). The supports_ fields are not applicable.
+        provider_finetune_id: The finetune ID for the provider, if applicable
         provider_options: Additional provider-specific configuration options
+        adapter_options: Additional options specific to the adapter. Top level key should be adapter ID.
     """
 
     name: ModelProviderName
     supports_structured_output: bool = True
     supports_data_gen: bool = True
+    untested_model: bool = False
+    provider_finetune_id: str | None = None
     provider_options: Dict = {}
+    adapter_options: Dict = {}
 
 
 class KilnModel(BaseModel):
@@ -122,6 +125,7 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openai,
                 provider_options={"model": "gpt-4o-mini"},
+                provider_finetune_id="gpt-4o-mini-2024-07-18",
             ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
@@ -138,6 +142,7 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openai,
                 provider_options={"model": "gpt-4o"},
+                provider_finetune_id="gpt-4o-2024-08-06",
             ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
@@ -257,6 +262,15 @@ built_in_models: List[KilnModel] = [
                 supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.1-8b-instruct"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_finetune_id="accounts/fireworks/models/llama-v3p1-8b-instruct",
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p1-8b-instruct"
+                },
+            ),
         ],
     ),
     # Llama 3.1 70b
@@ -271,7 +285,7 @@ built_in_models: List[KilnModel] = [
             ),
             KilnModelProvider(
                 name=ModelProviderName.amazon_bedrock,
-                # not sure how AWS manages to break this, but it's not working
+                # AWS 70b not working as well as the others.
                 supports_structured_output=False,
                 supports_data_gen=False,
                 provider_options={
@@ -287,6 +301,13 @@ built_in_models: List[KilnModel] = [
                 name=ModelProviderName.ollama,
                 provider_options={"model": "llama3.1:70b"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                provider_finetune_id="accounts/fireworks/models/llama-v3p1-70b-instruct",
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p1-70b-instruct"
+                },
+            ),
         ],
     ),
     # Llama 3.1 405b
@@ -311,6 +332,13 @@ built_in_models: List[KilnModel] = [
                 name=ModelProviderName.openrouter,
                 provider_options={"model": "meta-llama/llama-3.1-405b-instruct"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                # No finetune support. https://docs.fireworks.ai/fine-tuning/fine-tuning-models
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p1-405b-instruct"
+                },
+            ),
         ],
     ),
     # Mistral Nemo
@@ -348,6 +376,35 @@ built_in_models: List[KilnModel] = [
             ),
         ],
     ),
+    # Llama 3.2 1B
+    KilnModel(
+        family=ModelFamily.llama,
+        name=ModelName.llama_3_2_1b,
+        friendly_name="Llama 3.2 1B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={"model": "meta-llama/llama-3.2-1b-instruct"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={"model": "llama3.2:1b"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                provider_finetune_id="accounts/fireworks/models/llama-v3p2-1b-instruct",
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p2-1b-instruct"
+                },
+            ),
+        ],
+    ),
     # Llama 3.2 3B
     KilnModel(
         family=ModelFamily.llama,
@@ -366,6 +423,15 @@ built_in_models: List[KilnModel] = [
                 supports_data_gen=False,
                 provider_options={"model": "llama3.2"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                provider_finetune_id="accounts/fireworks/models/llama-v3p2-3b-instruct",
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p2-3b-instruct"
+                },
+            ),
         ],
     ),
     # Llama 3.2 11B
@@ -376,9 +442,12 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
-                supports_structured_output=False,
-                supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.2-11b-vision-instruct"},
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
             ),
             KilnModelProvider(
                 name=ModelProviderName.ollama,
@@ -386,6 +455,18 @@ built_in_models: List[KilnModel] = [
                 supports_data_gen=False,
                 provider_options={"model": "llama3.2-vision"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                # No finetune support. https://docs.fireworks.ai/fine-tuning/fine-tuning-models
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p2-11b-vision-instruct"
+                },
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
         ],
     ),
     # Llama 3.2 90B
@@ -396,16 +477,29 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
-                supports_structured_output=False,
-                supports_data_gen=False,
                 provider_options={"model": "meta-llama/llama-3.2-90b-vision-instruct"},
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
             ),
             KilnModelProvider(
                 name=ModelProviderName.ollama,
-                supports_structured_output=False,
-                supports_data_gen=False,
                 provider_options={"model": "llama3.2-vision:90b"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                # No finetune support. https://docs.fireworks.ai/fine-tuning/fine-tuning-models
+                provider_options={
+                    "model": "accounts/fireworks/models/llama-v3p2-90b-vision-instruct"
+                },
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
         ],
     ),
     # Phi 3.5
@@ -427,6 +521,15 @@ built_in_models: List[KilnModel] = [
                 supports_data_gen=False,
                 provider_options={"model": "microsoft/phi-3.5-mini-128k-instruct"},
             ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                # No finetune support. https://docs.fireworks.ai/fine-tuning/fine-tuning-models
+                provider_options={
+                    "model": "accounts/fireworks/models/phi-3-vision-128k-instruct"
+                },
+            ),
         ],
     ),
     # Gemma 2 2.6b
@@ -465,6 +568,7 @@ built_in_models: List[KilnModel] = [
                 supports_data_gen=False,
                 provider_options={"model": "google/gemma-2-9b-it"},
             ),
+            # fireworks AI errors - not allowing system role. Exclude until resolved.
         ],
     ),
     # Gemma 2 27b
@@ -488,290 +592,100 @@ built_in_models: List[KilnModel] = [
             ),
         ],
     ),
-]
-
-
-def get_model_and_provider(
-    model_name: str, provider_name: str
-) -> tuple[KilnModel | None, KilnModelProvider | None]:
-    model = next(filter(lambda m: m.name == model_name, built_in_models), None)
-    if model is None:
-        return None, None
-    provider = next(filter(lambda p: p.name == provider_name, model.providers), None)
-    # all or nothing
-    if provider is None or model is None:
-        return None, None
-    return model, provider
-
-
-def provider_name_from_id(id: str) -> str:
-    """
-    Converts a provider ID to its human-readable name.
-
-    Args:
-        id: The provider identifier string
-
-    Returns:
-        The human-readable name of the provider
-
-    Raises:
-        ValueError: If the provider ID is invalid or unhandled
-    """
-    if id in ModelProviderName.__members__:
-        enum_id = ModelProviderName(id)
-        match enum_id:
-            case ModelProviderName.amazon_bedrock:
-                return "Amazon Bedrock"
-            case ModelProviderName.openrouter:
-                return "OpenRouter"
-            case ModelProviderName.groq:
-                return "Groq"
-            case ModelProviderName.ollama:
-                return "Ollama"
-            case ModelProviderName.openai:
-                return "OpenAI"
-            case _:
-                # triggers pyright warning if I miss a case
-                raise_exhaustive_error(enum_id)
-
-    return "Unknown provider: " + id
-
-
-def raise_exhaustive_error(value: NoReturn) -> NoReturn:
-    raise ValueError(f"Unhandled enum value: {value}")
-
-
-@dataclass
-class ModelProviderWarning:
-    required_config_keys: List[str]
-    message: str
-
-
-provider_warnings: Dict[ModelProviderName, ModelProviderWarning] = {
-    ModelProviderName.amazon_bedrock: ModelProviderWarning(
-        required_config_keys=["bedrock_access_key", "bedrock_secret_key"],
-        message="Attempted to use Amazon Bedrock without an access key and secret set. \nGet your keys from https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/overview",
-    ),
-    ModelProviderName.openrouter: ModelProviderWarning(
-        required_config_keys=["open_router_api_key"],
-        message="Attempted to use OpenRouter without an API key set. \nGet your API key from https://openrouter.ai/settings/keys",
+    # Mixtral 8x7B
+    KilnModel(
+        family=ModelFamily.mixtral,
+        name=ModelName.mixtral_8x7b,
+        friendly_name="Mixtral 8x7B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                provider_options={
+                    "model": "accounts/fireworks/models/mixtral-8x7b-instruct-hf",
+                },
+                provider_finetune_id="accounts/fireworks/models/mixtral-8x7b-instruct-hf",
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "mistralai/mixtral-8x7b-instruct"},
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                supports_structured_output=False,
+                supports_data_gen=False,
+                provider_options={"model": "mixtral"},
+            ),
+        ],
     ),
-    ModelProviderName.groq: ModelProviderWarning(
-        required_config_keys=["groq_api_key"],
-        message="Attempted to use Groq without an API key set. \nGet your API key from https://console.groq.com/keys",
+    # Qwen 2.5 7B
+    KilnModel(
+        family=ModelFamily.qwen,
+        name=ModelName.qwen_2p5_7b,
+        friendly_name="Qwen 2.5 7B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "qwen/qwen-2.5-7b-instruct"},
+                # Tool calls not supported. JSON doesn't error, but fails.
+                supports_structured_output=False,
+                supports_data_gen=False,
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "qwen2.5"},
+            ),
+        ],
    ),
-    ModelProviderName.openai: ModelProviderWarning(
-        required_config_keys=["open_ai_api_key"],
-        message="Attempted to use OpenAI without an API key set. \nGet your API key from https://platform.openai.com/account/api-keys",
+    # Qwen 2.5 72B
+    KilnModel(
+        family=ModelFamily.qwen,
+        name=ModelName.qwen_2p5_72b,
+        friendly_name="Qwen 2.5 72B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "qwen/qwen-2.5-72b-instruct"},
+                # Not consistent with structure data. Works sometimes but not often
+                supports_structured_output=False,
+                supports_data_gen=False,
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                provider_options={"model": "qwen2.5:72b"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                provider_options={
+                    "model": "accounts/fireworks/models/qwen2p5-72b-instruct"
+                },
+                # Fireworks will start tuning, but it never finishes.
+                # provider_finetune_id="accounts/fireworks/models/qwen2p5-72b-instruct",
+                adapter_options={
+                    "langchain": {
+                        "with_structured_output_options": {"method": "json_mode"}
+                    }
+                },
+            ),
+        ],
     ),
-}
-
-
-def get_config_value(key: str):
-    try:
-        return Config.shared().__getattr__(key)
-    except AttributeError:
-        return None
-
-
-def check_provider_warnings(provider_name: ModelProviderName):
-    """
-    Validates that required configuration is present for a given provider.
-
-    Args:
-        provider_name: The provider to check
-
-    Raises:
-        ValueError: If required configuration keys are missing
-    """
-    warning_check = provider_warnings.get(provider_name)
-    if warning_check is None:
-        return
-    for key in warning_check.required_config_keys:
-        if get_config_value(key) is None:
-            raise ValueError(warning_check.message)
-
-
-async def langchain_model_from(
-    name: str, provider_name: str | None = None
-) -> BaseChatModel:
-    """
-    Creates a LangChain chat model instance for the specified model and provider.
-
-    Args:
-        name: The name of the model to instantiate
-        provider_name: Optional specific provider to use (defaults to first available)
-
-    Returns:
-        A configured LangChain chat model instance
-
-    Raises:
-        ValueError: If the model/provider combination is invalid or misconfigured
-    """
-    if name not in ModelName.__members__:
-        raise ValueError(f"Invalid name: {name}")
-
-    # Select the model from built_in_models using the name
-    model = next(filter(lambda m: m.name == name, built_in_models))
-    if model is None:
-        raise ValueError(f"Model {name} not found")
-
-    # If a provider is provided, select the provider from the model's provider_config
-    provider: KilnModelProvider | None = None
-    if model.providers is None or len(model.providers) == 0:
-        raise ValueError(f"Model {name} has no providers")
-    elif provider_name is None:
-        # TODO: priority order
-        provider = model.providers[0]
-    else:
-        provider = next(
-            filter(lambda p: p.name == provider_name, model.providers), None
-        )
-        if provider is None:
-            raise ValueError(f"Provider {provider_name} not found for model {name}")
-
-    check_provider_warnings(provider.name)
-
-    if provider.name == ModelProviderName.openai:
-        api_key = Config.shared().open_ai_api_key
-        return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
-    elif provider.name == ModelProviderName.groq:
-        api_key = Config.shared().groq_api_key
-        if api_key is None:
-            raise ValueError(
-                "Attempted to use Groq without an API key set. "
-                "Get your API key from https://console.groq.com/keys"
-            )
-        return ChatGroq(**provider.provider_options, groq_api_key=api_key)  # type: ignore[arg-type]
-    elif provider.name == ModelProviderName.amazon_bedrock:
-        api_key = Config.shared().bedrock_access_key
-        secret_key = Config.shared().bedrock_secret_key
-        # langchain doesn't allow passing these, so ugly hack to set env vars
-        os.environ["AWS_ACCESS_KEY_ID"] = api_key
-        os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
-        return ChatBedrockConverse(
-            **provider.provider_options,
-        )
-    elif provider.name == ModelProviderName.ollama:
-        # Ollama model naming is pretty flexible. We try a few versions of the model name
-        potential_model_names = []
-        if "model" in provider.provider_options:
-            potential_model_names.append(provider.provider_options["model"])
-        if "model_aliases" in provider.provider_options:
-            potential_model_names.extend(provider.provider_options["model_aliases"])
-
-        # Get the list of models Ollama supports
-        ollama_connection = await get_ollama_connection()
-        if ollama_connection is None:
-            raise ValueError("Failed to connect to Ollama. Ensure Ollama is running.")
-
-        for model_name in potential_model_names:
-            if ollama_model_supported(ollama_connection, model_name):
-                return ChatOllama(model=model_name, base_url=ollama_base_url())
-
-        raise ValueError(f"Model {name} not installed on Ollama")
-    elif provider.name == ModelProviderName.openrouter:
-        api_key = Config.shared().open_router_api_key
-        base_url = getenv("OPENROUTER_BASE_URL") or "https://openrouter.ai/api/v1"
-        return ChatOpenAI(
-            **provider.provider_options,
-            openai_api_key=api_key,  # type: ignore[arg-type]
-            openai_api_base=base_url,  # type: ignore[arg-type]
-            default_headers={
-                "HTTP-Referer": "https://getkiln.ai/openrouter",
-                "X-Title": "KilnAI",
-            },
-        )
-    else:
-        raise ValueError(f"Invalid model or provider: {name} - {provider_name}")
-
-
-def ollama_base_url() -> str:
-    """
-    Gets the base URL for Ollama API connections.
-
-    Returns:
-        The base URL to use for Ollama API calls, using environment variable if set
-        or falling back to localhost default
-    """
-    env_base_url = os.getenv("OLLAMA_BASE_URL")
-    if env_base_url is not None:
-        return env_base_url
-    return "http://localhost:11434"
-
-
-async def ollama_online() -> bool:
-    """
-    Checks if the Ollama service is available and responding.
-
-    Returns:
-        True if Ollama is available and responding, False otherwise
-    """
-    try:
-        httpx.get(ollama_base_url() + "/api/tags")
-    except httpx.RequestError:
-        return False
-    return True
-
-
-class OllamaConnection(BaseModel):
-    message: str
-    models: List[str]
-
-
-# Parse the Ollama /api/tags response
-def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
-    # Build a list of models we support for Ollama from the built-in model list
-    supported_ollama_models = [
-        provider.provider_options["model"]
-        for model in built_in_models
-        for provider in model.providers
-        if provider.name == ModelProviderName.ollama
-    ]
-    # Append model_aliases to supported_ollama_models
-    supported_ollama_models.extend(
-        [
-            alias
-            for model in built_in_models
-            for provider in model.providers
-            for alias in provider.provider_options.get("model_aliases", [])
-        ]
-    )
-
-    if "models" in tags:
-        models = tags["models"]
-        if isinstance(models, list):
-            model_names = [model["model"] for model in models]
-            available_supported_models = [
-                model
-                for model in model_names
-                if model in supported_ollama_models
-                or model in [f"{m}:latest" for m in supported_ollama_models]
-            ]
-            if available_supported_models:
-                return OllamaConnection(
-                    message="Ollama connected",
-                    models=available_supported_models,
-                )
-
-    return OllamaConnection(
-        message="Ollama is running, but no supported models are installed. Install one or more supported model, like 'ollama pull phi3.5'.",
-        models=[],
-    )
-
-
-async def get_ollama_connection() -> OllamaConnection | None:
-    """
-    Gets the connection status for Ollama.
-    """
-    try:
-        tags = requests.get(ollama_base_url() + "/api/tags", timeout=5).json()
-
-    except Exception:
-        return None
-
-    return parse_ollama_tags(tags)
-
-
-def ollama_model_supported(conn: OllamaConnection, model_name: str) -> bool:
-    return model_name in conn.models or f"{model_name}:latest" in conn.models
+]
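
For readers of this diff, here is a minimal sketch of how the new KilnModelProvider fields (provider_finetune_id and adapter_options) might be consumed downstream. It assumes only the module contents shown above (built_in_models, ModelName, ModelProviderName); the two helper functions are illustrative and are not part of the kiln-ai 0.7.0 API.

from kiln_ai.adapters.ml_model_list import (
    ModelName,
    ModelProviderName,
    built_in_models,
)


def finetune_capable_providers(
    model_name: ModelName,
) -> list[tuple[ModelProviderName, str]]:
    """Illustrative helper: providers that declare a provider_finetune_id for a model."""
    model = next((m for m in built_in_models if m.name == model_name), None)
    if model is None:
        return []
    return [
        (p.name, p.provider_finetune_id)
        for p in model.providers
        if p.provider_finetune_id is not None
    ]


def langchain_structured_output_options(
    model_name: ModelName, provider_name: ModelProviderName
) -> dict:
    """Illustrative helper: read adapter_options keyed by the "langchain" adapter ID."""
    model = next((m for m in built_in_models if m.name == model_name), None)
    if model is None:
        return {}
    provider = next((p for p in model.providers if p.name == provider_name), None)
    if provider is None:
        return {}
    return provider.adapter_options.get("langchain", {}).get(
        "with_structured_output_options", {}
    )


if __name__ == "__main__":
    # e.g. the Fireworks fine-tune ID added in 0.7.0 for Llama 3.1 8B
    print(finetune_capable_providers(ModelName.llama_3_1_8b))
    # e.g. {"method": "json_mode"} for Mixtral 8x7B via OpenRouter
    print(
        langchain_structured_output_options(
            ModelName.mixtral_8x7b, ModelProviderName.openrouter
        )
    )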